Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-21 23:48:57 +08:00)

Compare commits: v0.4.2...optimize-p (195 Commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| d5bf492f16 | |||
| f775a07e30 | |||
| 4f0d17c05c | |||
| 10c38e3e46 | |||
| cafb8e06c5 | |||
| cbb2f59cc8 | |||
| 0ab278ca31 | |||
| 7a64d24aad | |||
| 8c7bab79f5 | |||
| dfbe60dc62 | |||
| a66cf40b20 | |||
| f790ad3c50 | |||
| ed59a7ed23 | |||
| 1936d7bab0 | |||
| 996cf2de5c | |||
| 044793d8df | |||
| c2d6d2f960 | |||
| 8279078e21 | |||
| b9c0605a8e | |||
| 37464a0f74 | |||
| c354072828 | |||
| f081c3ce4b | |||
| 260d119e86 | |||
| a360ff80bb | |||
| 1197e02141 | |||
| 657579113f | |||
| e9899fb7a4 | |||
| a377f0bd5e | |||
| e9d3aa04f6 | |||
| a22dea54d3 | |||
| 533c217792 | |||
| 6d21fa1cad | |||
| b35be5403f | |||
| 45a1a69b98 | |||
| 87a658c812 | |||
| 429d89720e | |||
| a9bcc7afb2 | |||
| d79d9eaaff | |||
| f758505c73 | |||
| d910816c73 | |||
| 87d41c849d | |||
| e07aff9e52 | |||
| 5bf185a1c4 | |||
| 4fbcb0f27e | |||
| 7c3604fb68 | |||
| b1c255630d | |||
| eb6c50cdc2 | |||
| eecd864388 | |||
| ae495c74ea | |||
| 4238bc82f2 | |||
| 594392d27a | |||
| 18c1f16d86 | |||
| 5bd3c65072 | |||
| 616e600e0b | |||
| dfba529b40 | |||
| 5ae5ed1e60 | |||
| 290f4ada2b | |||
| dd8de11f0a | |||
| 9ba415588a | |||
| d4f3985907 | |||
| 890aa93d27 | |||
| fbdb7b3ee2 | |||
| 1102bef219 | |||
| f17a1a8f96 | |||
| d5a1697772 | |||
| 325c119961 | |||
| 8e192ff967 | |||
| e64fde4b01 | |||
| 919770957f | |||
| 6a50f4cafa | |||
| e3470f8753 | |||
| a1242324c9 | |||
| 5eda2ea02a | |||
| 2ba80bed27 | |||
| 6066253296 | |||
| ee3eea0a1b | |||
| a36de682d4 | |||
| eb6d3c264d | |||
| 97b030005c | |||
| a3a73ab069 | |||
| 8674f9880e | |||
| c74c913bfb | |||
| 5f6d10c14c | |||
| 9b9a10d6cb | |||
| 99eff67ba9 | |||
| 14772eeb8e | |||
| 757b62c495 | |||
| e941f88584 | |||
| f12c3b5b3d | |||
| d130b573a0 | |||
| 65ae8c2c8f | |||
| c3af44722c | |||
| 1937e29848 | |||
| f0eecee610 | |||
| 943e72ca56 | |||
| 546a97ef69 | |||
| da5a0b539d | |||
| 6287537a0c | |||
| b57e6c5949 | |||
| 27ce85476e | |||
| f68470e803 | |||
| 2e9a2227ec | |||
| c0724fc915 | |||
| 86b45ae065 | |||
| c5711ef985 | |||
| 48d5985a08 | |||
| 33e0823de5 | |||
| 26148120b3 | |||
| 0150a10630 | |||
| 8e7fb5d43a | |||
| 9a31a817a8 | |||
| 2060e93659 | |||
| 8435b207af | |||
| 10fa9eea21 | |||
| e08188081b | |||
| b5853f9963 | |||
| f09edd8a25 | |||
| 6979ade384 | |||
| 9216b9cc38 | |||
| 5e0391c040 | |||
| dbc0754ddf | |||
| 99caa49106 | |||
| 5c342570d7 | |||
| 973617ae02 | |||
| 30e754390c | |||
| 52f8107cf2 | |||
| fc0d9dfc3a | |||
| 361c461a12 | |||
| a5675d348b | |||
| e9cdd2b1e2 | |||
| 65bf2ac165 | |||
| 8a7cc254a0 | |||
| 29bc01bf3b | |||
| 676a99982f | |||
| dc72402b57 | |||
| ccb63a8245 | |||
| c579b750a0 | |||
| 4bfa7e7f75 | |||
| ac1fbf7fd2 | |||
| 33d3914b1e | |||
| 1356df53bd | |||
| ce532ff45c | |||
| 8bc68e198c | |||
| 0fca3cdcf2 | |||
| e7c46b9527 | |||
| 350f9e107f | |||
| 702bee461f | |||
| a7be4d0072 | |||
| a709e87a4f | |||
| 6eaccb7353 | |||
| e254497b66 | |||
| 4e12131089 | |||
| fcc2994be6 | |||
| 2e7796f2cf | |||
| 706588a77d | |||
| 6a0f617210 | |||
| dac6a3f6ed | |||
| 64b77dfd7e | |||
| 51d4094fda | |||
| e965d46184 | |||
| 208b71bcc1 | |||
| c833101740 | |||
| 379da6dcb5 | |||
| ebce310b74 | |||
| be0c5180ac | |||
| cea64430f6 | |||
| a3c124570a | |||
| ff5abcd746 | |||
| 0ee535b294 | |||
| 190bc838e1 | |||
| f12b20decc | |||
| 16bc0a098f | |||
| e288df0632 | |||
| 8b9241be3a | |||
| f942efb5a3 | |||
| 89579a201f | |||
| 230c4b38c1 | |||
| 20cfcdec99 | |||
| ad932a221d | |||
| 5510cf0e8a | |||
| 0f9a6e3d22 | |||
| f6a593093a | |||
| d7740ea4dc | |||
| cc466a3290 | |||
| 8344f7742b | |||
| 469f85c782 | |||
| 10760da800 | |||
| 478aed5827 | |||
| 63575bc2e1 | |||
| a98187cf72 | |||
| bd99d22629 | |||
| 19cb4716ee | |||
| e186d37cb1 | |||
| 323f27b904 | |||
| 0650e5935b |
@@ -1,7 +1,7 @@
import os
import zipfile

MAX_SIZE_MB = 100
MAX_SIZE_MB = 200


def print_top_10_largest_files(zip_file):
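For context, the hunk above bumps the wheel-size limit from 100 MB to 200 MB and touches the `print_top_10_largest_files` helper. A minimal sketch of what such a check can look like, assuming a directory of built wheels such as `dist/` (only the helper name and the limit come from the diff; the rest is illustrative, not the repository's actual script):

```python
import os
import sys
import zipfile

MAX_SIZE_MB = 200  # the limit raised from 100 to 200 in the hunk above


def print_top_10_largest_files(zip_file: str) -> None:
    """Print the ten largest entries inside a wheel (wheels are zip archives)."""
    with zipfile.ZipFile(zip_file, "r") as z:
        entries = sorted(z.infolist(), key=lambda e: e.file_size, reverse=True)
        for entry in entries[:10]:
            print(f"{entry.file_size / (1024 * 1024):8.2f} MB  {entry.filename}")


if __name__ == "__main__":
    # Check every wheel found under the given directory (e.g. `dist/`).
    wheel_dir = sys.argv[1]
    for name in os.listdir(wheel_dir):
        if not name.endswith(".whl"):
            continue
        path = os.path.join(wheel_dir, name)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        if size_mb > MAX_SIZE_MB:
            print(f"{name} is {size_mb:.1f} MB, above the {MAX_SIZE_MB} MB limit:")
            print_top_10_largest_files(path)
            sys.exit(1)
```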
@@ -1,10 +1,38 @@
# This script build the ROCm docker image and runs test inside it.
# This script runs test inside the corresponding ROCm docker container.
set -ex

# Print ROCm version
echo "--- ROCm info"
rocminfo

# cleanup older docker images
cleanup_docker() {
  # Get Docker's root directory
  docker_root=$(docker info -f '{{.DockerRootDir}}')
  if [ -z "$docker_root" ]; then
    echo "Failed to determine Docker root directory."
    exit 1
  fi
  echo "Docker root directory: $docker_root"
  # Check disk usage of the filesystem where Docker's root directory is located
  disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
  # Define the threshold
  threshold=70
  if [ "$disk_usage" -gt "$threshold" ]; then
    echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
    # Remove dangling images (those that are not tagged and not used by any container)
    docker image prune -f
    # Remove unused volumes
    docker volume prune -f
    echo "Docker images and volumes cleanup completed."
  else
    echo "Disk usage is below $threshold%. No cleanup needed."
  fi
}

# Call the cleanup docker function
cleanup_docker

echo "--- Resetting GPUs"

echo "reset" > /opt/amdgpu/etc/gpu_state

@@ -19,15 +47,16 @@ done

echo "--- Building container"
sha=$(git rev-parse --short HEAD)
container_name=rocm_${sha}
image_name=rocm_${sha}
container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)
docker build \
        -t ${container_name} \
        -t ${image_name} \
        -f Dockerfile.rocm \
        --progress plain \
        .

remove_docker_container() {
   docker rm -f ${container_name} || docker image rm -f ${container_name} || true
   docker rm -f ${container_name} || docker image rm -f ${image_name} || true
}
trap remove_docker_container EXIT

@@ -39,6 +68,6 @@ docker run \
        --rm \
        -e HF_TOKEN \
        --name ${container_name} \
        ${container_name} \
        /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//")
        ${image_name} \
        /bin/bash -c "${@}"
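The new `cleanup_docker` function is a plain threshold check on the filesystem backing Docker's root directory, followed by image and volume pruning. For illustration only, the same idea in Python (the 70% default matches the script's threshold; the function name and the use of `shutil.disk_usage`, which approximates the `df` percentage, are a sketch and not part of the diff):

```python
import shutil
import subprocess


def cleanup_docker_if_needed(threshold_percent: float = 70.0) -> None:
    """Prune dangling images and unused volumes when Docker's disk usage crosses a threshold."""
    # Ask Docker for its root directory, as the shell script does.
    docker_root = subprocess.check_output(
        ["docker", "info", "-f", "{{.DockerRootDir}}"], text=True).strip()
    usage = shutil.disk_usage(docker_root)
    used_percent = usage.used / usage.total * 100
    if used_percent > threshold_percent:
        print(f"Disk usage is {used_percent:.0f}%. Cleaning up Docker images and volumes...")
        subprocess.run(["docker", "image", "prune", "-f"], check=True)
        subprocess.run(["docker", "volume", "prune", "-f"], check=True)
    else:
        print(f"Disk usage is {used_percent:.0f}%. No cleanup needed.")


if __name__ == "__main__":
    cleanup_docker_if_needed()
```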
@@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.."
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)

# run python-based benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?

python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?

# run server-based benchmarks and upload the result to buildkite
@@ -74,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then
    exit $bench_serving_exit_code
fi

/workspace/buildkite-agent artifact upload openai-*.json
rm ShareGPT_V3_unfiltered_cleaned_split.json
/workspace/buildkite-agent artifact upload "*.json"
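With `--output-json` added above, both benchmark scripts write their summary metrics to JSON files that are then uploaded as Buildkite artifacts. A small, hedged example of consuming those files (the file names match the invocations above, and the key names follow the `results` dictionaries added in the benchmark diffs further below; the snippet itself is not part of the diff):

```python
import json

# latency_results.json is written by benchmark_latency.py --output-json
with open("latency_results.json") as f:
    latency = json.load(f)
print("avg latency (s):", latency["avg_latency"])
print("latency percentiles (s):", latency["percentiles"])  # JSON keys are strings

# throughput_results.json is written by benchmark_throughput.py --output-json
with open("throughput_results.json") as f:
    throughput = json.load(f)
print(f"{throughput['requests_per_second']:.2f} req/s, "
      f"{throughput['tokens_per_second']:.2f} tok/s")
```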
@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py
# Run the image
docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test

# offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"

# Run basic model test
docker exec cpu-test bash -c "cd tests;
  pip install pytest Pillow protobuf
  bash ../.buildkite/download-images.sh
  cd ../
  pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
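Both the old and the new CPU test invoke `examples/offline_inference.py` inside the container. For orientation, a minimal offline-inference script in the spirit of that example (a sketch using the public `LLM` API, not the file from the repository):

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# A small model keeps the CPU test fast; VLLM_CPU_KVCACHE_SPACE (GiB) bounds the KV cache.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")
```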
@ -5,13 +5,16 @@
|
||||
|
||||
steps:
|
||||
- label: Regression Test
|
||||
mirror_hardwares: [amd]
|
||||
command: pytest -v -s test_regression.py
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: AsyncEngine Test
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s async_engine
|
||||
|
||||
- label: Basic Correctness Test
|
||||
mirror_hardwares: [amd]
|
||||
commands:
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
@ -24,59 +27,68 @@ steps:
|
||||
command: pytest -v -s core
|
||||
|
||||
- label: Distributed Comm Ops Test
|
||||
command: pytest -v -s test_comm_ops.py
|
||||
working_dir: "/vllm-workspace/tests/distributed"
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s distributed/test_comm_ops.py
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
|
||||
- label: Distributed Tests
|
||||
working_dir: "/vllm-workspace/tests/distributed"
|
||||
|
||||
num_gpus: 2 # only support 1 or 2 for now.
|
||||
mirror_hardwares: [amd]
|
||||
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- pytest -v -s test_pynccl_library.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist.py
|
||||
|
||||
- label: Distributed Tests (Multiple Groups)
|
||||
working_dir: "/vllm-workspace/tests/distributed"
|
||||
#mirror_hardwares: [amd]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
commands:
|
||||
- pytest -v -s test_pynccl.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
|
||||
- label: Engine Test
|
||||
mirror_hardwares: [amd]
|
||||
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
|
||||
|
||||
- label: Entrypoints Test
|
||||
mirror_hardwares: [amd]
|
||||
|
||||
commands:
|
||||
# these tests have to be separated, because each one will allocate all posible GPU memory
|
||||
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
|
||||
- pytest -v -s entrypoints/test_server_oot_registration.py
|
||||
- pytest -v -s test_inputs.py
|
||||
- pytest -v -s entrypoints -m llm
|
||||
- pytest -v -s entrypoints -m openai
|
||||
|
||||
- label: Examples Test
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
mirror_hardwares: [amd]
|
||||
commands:
|
||||
# install aws cli for llava_example.py
|
||||
- pip install awscli
|
||||
# install tensorizer for tensorize_vllm_model.py
|
||||
- pip install awscli tensorizer
|
||||
- python3 offline_inference.py
|
||||
- python3 offline_inference_with_prefix.py
|
||||
- python3 llm_engine_example.py
|
||||
- python3 llava_example.py
|
||||
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
|
||||
- label: Kernels Test %N
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
|
||||
- label: Models Test
|
||||
mirror_hardwares: [amd]
|
||||
#mirror_hardwares: [amd]
|
||||
commands:
|
||||
- bash ../.buildkite/download-images.sh
|
||||
- pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
|
||||
- pytest -v -s models --ignore=models/test_llava.py
|
||||
|
||||
- label: Llava Test
|
||||
mirror_hardwares: [amd]
|
||||
@ -90,31 +102,53 @@ steps:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
- label: Samplers Test
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s samplers
|
||||
|
||||
- label: LogitsProcessor Test
|
||||
mirror_hardwares: [amd]
|
||||
command: pytest -v -s test_logits_processor.py
|
||||
|
||||
- label: Utils Test
|
||||
command: pytest -v -s test_utils.py
|
||||
|
||||
- label: Worker Test
|
||||
mirror_hardwares: [amd]
|
||||
command: pytest -v -s worker
|
||||
|
||||
- label: Speculative decoding tests
|
||||
mirror_hardwares: [amd]
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s spec_decode
|
||||
|
||||
- label: LoRA Test %N
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
|
||||
parallelism: 4
|
||||
|
||||
- label: LoRA Long Context (Distributed)
|
||||
#mirror_hardwares: [amd]
|
||||
num_gpus: 4
|
||||
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
||||
commands:
|
||||
# Temporarily run this way because we cannot clean up GPU mem usage
|
||||
# for multi GPU tests.
|
||||
# TODO(sang): Fix it.
|
||||
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
|
||||
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
|
||||
- pytest -v -s lora/test_long_context.py::test_self_consistency
|
||||
- pytest -v -s lora/test_long_context.py::test_quality
|
||||
- pytest -v -s lora/test_long_context.py::test_max_len
|
||||
|
||||
- label: Tensorizer Test
|
||||
#mirror_hardwares: [amd]
|
||||
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
|
||||
|
||||
- label: Metrics Test
|
||||
mirror_hardwares: [amd]
|
||||
command: pytest -v -s metrics
|
||||
|
||||
- label: Quantization Test
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s quantization
|
||||
|
||||
- label: Benchmarks
|
||||
|
||||
.buildkite/test-template-aws.j2 (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %}
|
||||
{% set default_working_dir = "/vllm-workspace/tests" %}
|
||||
|
||||
steps:
|
||||
- label: ":docker: build image"
|
||||
agents:
|
||||
queue: cpu_queue
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
|
||||
- "docker push {{ docker_image }}"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
retry:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
- exit_status: -10 # Agent was lost
|
||||
limit: 5
|
||||
- wait
|
||||
|
||||
{% for step in steps %}
|
||||
- label: "{{ step.label }}"
|
||||
agents:
|
||||
{% if step.no_gpu %}
|
||||
queue: cpu_queue
|
||||
{% elif step.num_gpus == 2 or step.num_gpus == 4 %}
|
||||
queue: gpu_4_queue
|
||||
{% else %}
|
||||
queue: gpu_1_queue
|
||||
{% endif %}
|
||||
soft_fail: true
|
||||
{% if step.parallelism %}
|
||||
parallelism: {{ step.parallelism }}
|
||||
{% endif %}
|
||||
retry:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
- exit_status: -10 # Agent was lost
|
||||
limit: 5
|
||||
plugins:
|
||||
- docker#v5.2.0:
|
||||
image: {{ docker_image }}
|
||||
always-pull: true
|
||||
propagate-environment: true
|
||||
{% if not step.no_gpu %}
|
||||
gpus: all
|
||||
{% endif %}
|
||||
command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}"]
|
||||
environment:
|
||||
- VLLM_USAGE_SOURCE=ci-test
|
||||
- HF_TOKEN
|
||||
{% if step.label == "Speculative decoding tests" %}
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS
|
||||
{% endif %}
|
||||
volumes:
|
||||
- /dev/shm:/dev/shm
|
||||
{% endfor %}
|
||||
@ -3,9 +3,8 @@
|
||||
{% set default_working_dir = "/vllm-workspace/tests" %}
|
||||
|
||||
steps:
|
||||
|
||||
- label: ":docker: build image"
|
||||
commands:
|
||||
commands:
|
||||
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
|
||||
- "docker push {{ docker_image }}"
|
||||
env:
|
||||
@ -14,6 +13,8 @@ steps:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
- exit_status: -10 # Agent was lost
|
||||
limit: 5
|
||||
- wait
|
||||
|
||||
- group: "AMD Tests"
|
||||
@ -24,7 +25,7 @@ steps:
|
||||
- label: "AMD: {{ step.label }}"
|
||||
agents:
|
||||
queue: amd
|
||||
command: bash .buildkite/run-amd-test.sh "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
|
||||
command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
{% endif %}
|
||||
@ -39,6 +40,8 @@ steps:
|
||||
|
||||
- label: "Intel Test"
|
||||
depends_on: ~
|
||||
agents:
|
||||
queue: intel
|
||||
command: bash .buildkite/run-cpu-test.sh
|
||||
|
||||
{% for step in steps %}
|
||||
@ -53,6 +56,8 @@ steps:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
- exit_status: -10 # Agent was lost
|
||||
limit: 5
|
||||
plugins:
|
||||
- kubernetes:
|
||||
podSpec:
|
||||
|
||||
.clang-format (new file, 26 lines)
@@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80

# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left

# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false

# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash

IncludeCategories:
  - Regex: '^<'
    Priority: 4
  - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
    Priority: 3
  - Regex: '^"(qoda|\.\.)/'
    Priority: 2
  - Regex: '.*'
    Priority: 1
.github/ISSUE_TEMPLATE/400-bug report.yml (2 lines changed, vendored)
@@ -59,6 +59,8 @@ body:

      Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.

      Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues.

      If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
    placeholder: |
      A clear and concise description of what the bug is.
.github/workflows/clang-format.yml (new file, 42 lines, vendored)
@@ -0,0 +1,42 @@
name: clang-format

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install clang-format==18.1.5
    - name: Running clang-format
      run: |
        EXCLUDES=(
            'csrc/moe/topk_softmax_kernels.cu'
            'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
            'csrc/punica/bgmv/bgmv_config.h'
            'csrc/punica/bgmv/bgmv_impl.cuh'
            'csrc/punica/bgmv/vec_dtypes.cuh'
            'csrc/punica/punica_ops.cu'
            'csrc/punica/type_convert.h'
        )
        find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
            | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
            | xargs clang-format --dry-run --Werror
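The formatting step above boils down to "collect the csrc sources, drop the excluded files, run clang-format in check-only mode". A rough Python equivalent of that selection logic, for illustration only (the exclude list and flags come from the workflow; the script itself is an assumption, not something the CI runs):

```python
import subprocess
from pathlib import Path

EXCLUDES = {
    "csrc/moe/topk_softmax_kernels.cu",
    "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu",
    "csrc/punica/bgmv/bgmv_config.h",
    "csrc/punica/bgmv/bgmv_impl.cuh",
    "csrc/punica/bgmv/vec_dtypes.cuh",
    "csrc/punica/punica_ops.cu",
    "csrc/punica/type_convert.h",
}
SUFFIXES = {".h", ".cpp", ".cu", ".cuh"}

# Collect every C/C++/CUDA source under csrc/ that is not explicitly excluded.
files = [
    str(p) for p in Path("csrc").rglob("*")
    if p.suffix in SUFFIXES and str(p) not in EXCLUDES
]

# --dry-run --Werror makes clang-format fail instead of rewriting files.
subprocess.run(["clang-format", "--dry-run", "--Werror", *files], check=True)
```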
.github/workflows/mypy.yaml (1 line changed, vendored)
@@ -37,6 +37,7 @@ jobs:
        mypy vllm/distributed --config-file pyproject.toml
        mypy vllm/entrypoints --config-file pyproject.toml
        mypy vllm/executor --config-file pyproject.toml
        mypy vllm/multimodal --config-file pyproject.toml
        mypy vllm/usage --config-file pyproject.toml
        mypy vllm/*.py --config-file pyproject.toml
        mypy vllm/transformers_utils --config-file pyproject.toml
.github/workflows/publish.yml (3 lines changed, vendored)
@@ -58,6 +58,9 @@ jobs:

      - name: Setup ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          create-symlink: true
          key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}

      - name: Set up Linux Env
        if: ${{ runner.os == 'Linux' }}
@ -167,19 +167,47 @@ set(VLLM_EXT_SRC
|
||||
"csrc/layernorm_kernels.cu"
|
||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
||||
"csrc/quantization/gptq/q_gemm.cu"
|
||||
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
|
||||
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/moe_align_block_size_kernels.cu"
|
||||
"csrc/pybind.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(FetchContent)
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# CUTLASS 3.5.0
|
||||
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
|
||||
)
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/custom_all_reduce.cu")
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
|
||||
|
||||
#
|
||||
# The CUTLASS kernels for Hopper require sm90a to be enabled.
|
||||
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
|
||||
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
||||
set_source_files_properties(
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
|
||||
PROPERTIES
|
||||
COMPILE_FLAGS
|
||||
"-gencode arch=compute_90a,code=sm_90a")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
define_gpu_extension_target(
|
||||
@ -189,6 +217,7 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
WITH_SOABI)
|
||||
|
||||
#
|
||||
@ -219,7 +248,8 @@ set(VLLM_PUNICA_EXT_SRC
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
||||
"csrc/punica/punica_ops.cc")
|
||||
"csrc/punica/punica_ops.cu"
|
||||
"csrc/punica/punica_pybind.cpp")
|
||||
|
||||
#
|
||||
# Copy GPU compilation flags+update for punica
|
||||
@ -243,6 +273,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
||||
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
|
||||
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
|
||||
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
||||
endif()
|
||||
|
||||
if (VLLM_PUNICA_GPU_ARCHES)
|
||||
@ -277,9 +310,7 @@ add_custom_target(default)
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
message(STATUS "Enabling C extension.")
|
||||
add_dependencies(default _C)
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Enabling moe extension.")
|
||||
add_dependencies(default _moe_C)
|
||||
|
||||
|
||||
Dockerfile (27 lines changed)
@@ -79,31 +79,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||
RUN python3 check-wheel-size.py dist
|
||||
|
||||
# the `vllm_nccl` package must be installed from source distribution
|
||||
# pip is too smart to store a wheel in the cache, and other CI jobs
|
||||
# will directly use the wheel from the cache, which is not what we want.
|
||||
# we need to remove it manually
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip cache remove vllm_nccl*
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
FROM dev as flash-attn-builder
|
||||
# max jobs used for build
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# flash attention version
|
||||
ARG flash_attn_version=v2.5.8
|
||||
ENV FLASH_ATTN_VERSION=${flash_attn_version}
|
||||
|
||||
WORKDIR /usr/src/flash-attention-v2
|
||||
|
||||
# Download the wheel or build it if a pre-compiled release doesn't exist
|
||||
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
|
||||
--no-build-isolation --no-deps --no-cache-dir
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
|
||||
@ -122,10 +99,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
pip install dist/*.whl --verbose
|
||||
|
||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.

FROM ubuntu:22.04
FROM ubuntu:22.04 AS cpu-test-1

RUN apt-get update -y \
    && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
@@ -9,6 +9,8 @@ RUN apt-get update -y \
RUN pip install --upgrade pip \
    && pip install wheel packaging ninja setuptools>=49.4.0 numpy

FROM cpu-test-1 AS build

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm
@@ -17,4 +19,8 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py

RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

WORKDIR /workspace/

RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks

CMD ["/bin/bash"]
@@ -92,16 +92,24 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
WORKDIR /vllm-workspace
COPY . .

#RUN python3 -m pip install pynvml # to be removed eventually
RUN python3 -m pip install --upgrade pip numba

# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1

ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so

RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -U -r requirements-rocm.txt \
    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
    && python3 setup.py install \
    && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
    && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
    && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \
    && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

CMD ["/bin/bash"]
README.md (82 lines changed)
@@ -14,6 +14,17 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
|
||||
|
||||
We are thrilled to announce our fourth vLLM Meetup!
|
||||
The vLLM team will share recent updates and roadmap.
|
||||
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
|
||||
Please register [here](https://lu.ma/agivllm) and join us!
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||
@ -51,41 +62,14 @@ vLLM is flexible and easy to use with:
|
||||
- (Experimental) Prefix caching support
|
||||
- (Experimental) Multi-lora support
|
||||
|
||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||
- Transformer-like LLMs (e.g., Llama)
|
||||
- Mixture-of-Expert LLMs (e.g., Mixtral)
|
||||
- Multi-modal LLMs (e.g., LLaVA)
|
||||
|
||||
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
|
||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
||||
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
|
||||
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
|
||||
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
|
||||
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
|
||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
||||
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
|
||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
||||
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
||||
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
|
||||
- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
|
||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
||||
- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
|
||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
|
||||
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
|
||||
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
||||
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
||||
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
|
||||
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
|
||||
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||
|
||||
## Getting Started
|
||||
|
||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||
|
||||
@ -93,9 +77,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
|
||||
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
|
||||
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
|
||||
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
|
||||
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
|
||||
@ -105,6 +87,32 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
|
||||
We welcome and value any contributions and collaborations.
|
||||
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
||||
|
||||
## Sponsors
|
||||
|
||||
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
|
||||
|
||||
<!-- Note: Please sort them in alphabetical order. -->
|
||||
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
|
||||
|
||||
- a16z
|
||||
- AMD
|
||||
- Anyscale
|
||||
- AWS
|
||||
- Crusoe Cloud
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Dropbox
|
||||
- Lambda Lab
|
||||
- NVIDIA
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
- Trainy
|
||||
- UC Berkeley
|
||||
- UC San Diego
|
||||
|
||||
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
|
||||
|
||||
## Citation
|
||||
|
||||
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||
|
||||
@ -89,6 +89,9 @@ async def async_request_tgi(
|
||||
output.latency = most_recent_timestamp - st
|
||||
output.success = True
|
||||
output.generated_text = data["generated_text"]
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
@ -276,6 +279,9 @@ async def async_request_openai_completions(
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import PromptStrictInputs
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
|
||||
@ -18,6 +20,8 @@ def main(args: argparse.Namespace):
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(model=args.model,
|
||||
speculative_model=args.speculative_model,
|
||||
num_speculative_tokens=args.num_speculative_tokens,
|
||||
tokenizer=args.tokenizer,
|
||||
quantization=args.quantization,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
@ -28,9 +32,11 @@ def main(args: argparse.Namespace):
|
||||
quantization_param_path=args.quantization_param_path,
|
||||
device=args.device,
|
||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
||||
use_v2_block_manager=args.use_v2_block_manager,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
download_dir=args.download_dir,
|
||||
block_size=args.block_size)
|
||||
block_size=args.block_size,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
@ -44,7 +50,9 @@ def main(args: argparse.Namespace):
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
|
||||
dummy_inputs: List[PromptStrictInputs] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
@ -55,13 +63,13 @@ def main(args: argparse.Namespace):
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir))) as p:
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
llm.generate(dummy_inputs,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
print(p.key_averages())
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
llm.generate(dummy_inputs,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
end_time = time.perf_counter()
|
||||
@ -93,12 +101,24 @@ def main(args: argparse.Namespace):
|
||||
for percentage, percentile in zip(percentages, percentiles):
|
||||
print(f'{percentage}% percentile latency: {percentile} seconds')
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"avg_latency": np.mean(latencies),
|
||||
"latencies": latencies.tolist(),
|
||||
"percentiles": dict(zip(percentages, percentiles.tolist())),
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||
parser.add_argument('--speculative-model', type=str, default=None)
|
||||
parser.add_argument('--num-speculative-tokens', type=int, default=None)
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
@ -137,15 +157,13 @@ if __name__ == '__main__':
|
||||
action='store_true',
|
||||
help='enforce eager mode and disable CUDA graph')
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
'--kv-cache-dtype',
|
||||
type=str,
|
||||
choices=['auto', 'fp8'],
|
||||
default='auto',
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
||||
default="auto",
|
||||
help='Data type for kv cache storage. If "auto", will use model '
|
||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
||||
parser.add_argument(
|
||||
'--quantization-param-path',
|
||||
type=str,
|
||||
@ -181,6 +199,7 @@ if __name__ == '__main__':
|
||||
action='store_true',
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
parser.add_argument('--use-v2-block-manager', action='store_true')
|
||||
parser.add_argument(
|
||||
"--ray-workers-use-nsight",
|
||||
action='store_true',
|
||||
@ -191,5 +210,16 @@ if __name__ == '__main__':
|
||||
default=None,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of huggingface')
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the latency results in JSON format.')
|
||||
parser.add_argument('--gpu-memory-utilization',
|
||||
type=float,
|
||||
default=0.9,
|
||||
help='the fraction of GPU memory to be used for '
|
||||
'the model executor, which can range from 0 to 1.'
|
||||
'If unspecified, will use the default value of 0.9.')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -17,6 +17,10 @@ On the client side, run:
|
||||
--dataset-path <path to dataset> \
|
||||
--request-rate <request_rate> \ # By default <request_rate> is inf
|
||||
--num-prompts <num_prompts> # By default <num_prompts> is 1000
|
||||
|
||||
when using tgi backend, add
|
||||
--endpoint /generate_stream
|
||||
to the end of the command above.
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
@ -211,6 +215,11 @@ def calculate_metrics(
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
|
||||
if completed == 0:
|
||||
warnings.warn(
|
||||
"All requests failed. This is likely due to a misconfiguration "
|
||||
"on the benchmark arguments.",
|
||||
stacklevel=2)
|
||||
metrics = BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
@ -222,9 +231,9 @@ def calculate_metrics(
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(tpots) * 1000,
|
||||
median_tpot_ms=np.median(tpots) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@ -246,6 +255,24 @@ async def benchmark(
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len = input_requests[0]
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
if not test_output.success:
|
||||
raise ValueError(
|
||||
"Initial test run failed - Please make sure benchmark arguments "
|
||||
f"are correctly specified. Error: {test_output.error}")
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
@ -242,6 +242,18 @@ def main(args: argparse.Namespace):
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": total_num_tokens / elapsed_time,
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
@ -311,15 +323,13 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="enforce eager execution")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
'--kv-cache-dtype',
|
||||
type=str,
|
||||
choices=["auto", "fp8"],
|
||||
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
help='Data type for kv cache storage. If "auto", will use model '
|
||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
||||
parser.add_argument(
|
||||
'--quantization-param-path',
|
||||
type=str,
|
||||
@ -353,6 +363,11 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of huggingface')
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py (new file, 352 lines)
@@ -0,0 +1,352 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.tensor) -> torch.tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.tensor) -> torch.tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.tensor, torch.tensor]:
|
||||
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
if dtype == torch.int8:
|
||||
return to_int8(a), to_int8(b)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return to_fp8(a), to_fp8(b)
|
||||
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
|
||||
def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
return torch.mm(a, b)
|
||||
|
||||
|
||||
def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
|
||||
scale_a: torch.tensor, scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
use_fast_accum=True)
|
||||
|
||||
|
||||
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
return ops.cutlass_scaled_mm_dq(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
|
||||
sub_label: str, fn: Callable, description: str) -> TMeasurement:
|
||||
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"a": a,
|
||||
"b": b,
|
||||
"scale_a": scale_a,
|
||||
"scale_b": scale_b,
|
||||
"out_dtype": out_dtype,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.int8
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
|
||||
timers = []
|
||||
# pytorch impl
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_i8_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
||||
torch.bfloat16, label, sub_label, cutlass_impl,
|
||||
"cutlass_i8_i8_bf16_scaled_mm"))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    timers = []

    # pytorch impl: bf16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))

    # pytorch impl: bf16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))

    # pytorch impl: fp16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))

    # pytorch impl: fp16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))

    # cutlass impl: bf16 output
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.bfloat16, label, sub_label, cutlass_impl,
                 "cutlass_fp8_fp8_bf16_scaled_mm"))
    # cutlass impl: fp16 output
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.float16, label, sub_label, cutlass_impl,
                 "cutlass_fp8_fp8_fp16_scaled_mm"))

    return timers


def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
          sub_label: str) -> Iterable[TMeasurement]:
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(dtype: torch.dtype,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:

    results = []
    for m, k, n in MKNs:
        timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})")
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(data: Iterable[TMeasurement],
                MKNs: Iterable[Tuple[int, int, int]],
                base_description: str,
                timestamp=None):

    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = argparse.ArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--dtype",
                        type=to_torch_dtype,
                        required=True,
                        help="Available options are ['int8', 'fp8']")
    subparsers = parser.add_subparsers(dest="cmd")

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
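A minimal sketch (not part of the diff) of loading the .pkl file that make_output() / run_model_bench() write and re-printing the comparison table; the file name below is an assumed example, substitute the timestamped name a real run produces:

import pickle as pkl

import torch.utils.benchmark as TBenchmark

# Assumed example path; a real run writes e.g. "square_bench-<dtype>-<timestamp>.pkl".
with open("square_bench-torch.float8_e4m3fn-1715000000.pkl", "rb") as f:
    measurements = pkl.load(f)

# The pickle is a list of raw torch.utils.benchmark Measurements.
TBenchmark.Compare(measurements).print()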
benchmarks/cutlass_benchmarks/weight_shapes.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# Weight Shapes are in the format
#  ([K, N], TP_SPLIT_DIM)
# Example:
#  A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#  A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "mistralai/Mistral-7B-v0.1": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-2-7b-hf": [
        ([4096, 12288], 1),
        ([4096, 4096], 0),
        ([4096, 22016], 1),
        ([11008, 4096], 0),
    ],
    "meta-llama/Llama-2-13b-hf": [
        ([5120, 15360], 1),
        ([5120, 5120], 0),
        ([5120, 27648], 1),
        ([13824, 5120], 0),
    ],
    "meta-llama/Llama-2-70b-hf": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
}
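The TP split described in the header comment can be applied with a small helper; a minimal sketch (not part of the diff, the helper name is illustrative and mirrors model_shapes() in w8a8_benchmarks.py above):

import copy

from weight_shapes import WEIGHT_SHAPES  # assumed to be importable from this directory


def shapes_for_tp(model_name: str, tp_size: int):
    """Divide the TP-split dimension of each (K, N) pair by tp_size."""
    kns = []
    for kn, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
        kn[tp_split_dim] //= tp_size
        kns.append(kn)
    return kns


# e.g. ([14336, 4096], 0) at TP2 -> [7168, 4096]
print(shapes_for_tp("meta-llama/Llama-2-7b-hf", 2))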
benchmarks/kernels/benchmark_marlin.py (new file, 233 lines)
@@ -0,0 +1,233 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]


def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
              size_m, size_k, size_n):
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
                                         group_size, size_m, size_k, size_n))

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())

    # Marlin quant
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, num_bits, group_size, act_order)

    # Marlin_24 quant
    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

    # GPTQ quant
    (w_ref, q_w, s, g_idx,
     rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

    # Prepare
    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                       GPTQ_MARLIN_MAX_PARALLEL)

    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                          GPTQ_MARLIN_24_MAX_PARALLEL)

    globals = {
        # Gen params
        "num_bits": num_bits,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        "a_tmp": a_tmp,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_rand_perm": marlin_rand_perm,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # Marlin_24 params
        "marlin_24_w_ref": marlin_24_w_ref,
        "marlin_24_q_w_comp": marlin_24_q_w_comp,
        "marlin_24_meta": marlin_24_meta,
        "marlin_24_s": marlin_24_s,
        "marlin_24_workspace": marlin_24_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
    }

    min_run_time = 1

    # Warmup pytorch
    for i in range(5):
        torch.matmul(a, marlin_w_ref)

    results.append(
        benchmark.Timer(
            stmt="torch.matmul(a, marlin_w_ref)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
        results.append(
            benchmark.Timer(
                stmt=
                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_24_gemm",
            ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_repack",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results = []

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k = layer[0]
            size_n = layer[1]

            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue

            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for act_order in ACT_ORDER_OPTS:
                if len(args.limit_act_order
                       ) > 0 and act_order not in args.limit_act_order:
                    continue

                for is_k_full in K_FULL_OPTS:
                    if len(args.limit_k_full
                           ) > 0 and is_k_full not in args.limit_k_full:
                        continue

                    for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
                        if len(args.limit_num_bits
                               ) > 0 and num_bits not in args.limit_num_bits:
                            continue

                        for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
                            if len(
                                    args.limit_group_size
                            ) > 0 and group_size not in args.limit_group_size:
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and (group_size == size_k
                                              or group_size == -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(results, model, act_order, is_k_full,
                                          num_bits, group_size, size_m, size_k,
                                          size_n)

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1  # noqa E501
#
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
    parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
    parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
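The deeply nested sweep in main() above could equivalently be flattened with itertools.product; a minimal illustrative sketch (not part of the diff) that assumes the same constants and --limit-* arguments as the script:

import itertools


def iter_configs(args, size_k, size_n):
    """Yield (act_order, is_k_full, num_bits, group_size, size_m) combinations."""
    for act_order, is_k_full, num_bits, group_size, size_m in itertools.product(
            ACT_ORDER_OPTS, K_FULL_OPTS, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
            GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, args.batch_sizes):
        # Skip configurations filtered out by the --limit-* flags.
        if args.limit_act_order and act_order not in args.limit_act_order:
            continue
        if args.limit_k_full and is_k_full not in args.limit_k_full:
            continue
        if args.limit_num_bits and num_bits not in args.limit_num_bits:
            continue
        if args.limit_group_size and group_size not in args.limit_group_size:
            continue
        # act_order requires a group size strictly smaller than size_k.
        if act_order and (group_size == size_k or group_size == -1):
            continue
        yield act_order, is_k_full, num_bits, group_size, size_m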
@@ -11,25 +11,36 @@ from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                   get_config_file_name)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def main(dtype: str):
def main(model, tp_size, gpu, dtype: str):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method, dtype=dtype)
        run_grid(bs,
                 model=model,
                 method=method,
                 gpu=gpu,
                 tp_size=tp_size,
                 dtype=dtype)


def run_grid(bs, method, dtype: str):
    d_model = 4096
def run_grid(bs, model, method, gpu, tp_size, dtype: str):
    if model == '8x7B':
        d_model = 4096
        model_intermediate_size = 14336
        num_layers = 32
    elif model == '8x22B':
        d_model = 6144
        model_intermediate_size = 16384
        num_layers = 56
    else:
        raise ValueError(f'Unsupported Mixtral model {model}')
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    # tp_size = 2
    num_calls = 100

    num_warmup_trials = 1

@@ -211,5 +222,18 @@ if __name__ == "__main__":
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    parser.add_argument('--model',
                        type=str,
                        default='8x7B',
                        choices=['8x7B', '8x22B'],
                        help='The Mixtral model to benchmark')
    parser.add_argument('--tp-size',
                        type=int,
                        default=2,
                        help='Tensor parallel size')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help="GPU ID for benchmarking")
    args = parser.parse_args()
    sys.exit(main(args.dtype))
    sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))

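The if/elif selection of Mixtral geometry in run_grid() above could also be written as a lookup table; a minimal illustrative sketch (not part of the diff, the table name is hypothetical):

# (d_model, model_intermediate_size, num_layers) per Mixtral variant,
# matching the if/elif branches in run_grid().
MIXTRAL_CONFIGS = {
    '8x7B': (4096, 14336, 32),
    '8x22B': (6144, 16384, 56),
}


def mixtral_geometry(model: str):
    try:
        return MIXTRAL_CONFIGS[model]
    except KeyError:
        raise ValueError(f'Unsupported Mixtral model {model}') from None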
@@ -170,7 +170,7 @@ if __name__ == '__main__':
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        choices=[64, 80, 96, 112, 128, 192, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")

@@ -183,13 +183,11 @@ if __name__ == '__main__':
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
    args = parser.parse_args()
    print(args)

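The expanded --kv-cache-dtype choices line up with the Fp8KVCacheDataType enum introduced later in this diff (kAuto = 0, kFp8E4M3 = 1, kFp8E5M2 = 2); a minimal illustrative Python mapping (not part of the diff), assuming the "fp8 (=fp8_e4m3)" aliasing stated in the help text:

# Illustrative only: mirrors Fp8KVCacheDataType {kAuto=0, kFp8E4M3=1, kFp8E5M2=2}.
KV_CACHE_DTYPE_TO_ENUM = {
    "auto": 0,        # kAuto: follow the model data type
    "fp8": 1,         # alias for fp8_e4m3 per the --kv-cache-dtype help text
    "fp8_e4m3": 1,    # kFp8E4M3
    "fp8_e5m2": 2,    # kFp8E5M2
}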
@@ -93,7 +93,7 @@ if __name__ == '__main__':
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        choices=[64, 80, 96, 112, 128, 192, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",

benchmarks/kernels/benchmark_shapes.py (new file, 75 lines)
@@ -0,0 +1,75 @@
WEIGHT_SHAPES = {
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}
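The per-TP entries above are the TP1 shapes with one dimension divided by the TP degree, following the same ([K, N], TP_SPLIT_DIM) convention as benchmarks/cutlass_benchmarks/weight_shapes.py earlier in this diff; a minimal illustrative consistency check (not part of the diff, split dimensions assumed from that file):

from benchmark_shapes import WEIGHT_SHAPES  # assumed importable from this directory

# Assumed TP-split dimension per layer, taken from the ([K, N], TP_SPLIT_DIM)
# entries in benchmarks/cutlass_benchmarks/weight_shapes.py.
SPLIT_DIMS = [1, 0, 1, 0]


def check(model: str, tp: int):
    tp1 = WEIGHT_SHAPES[f"{model}/TP1"]
    tpn = WEIGHT_SHAPES[f"{model}/TP{tp}"]
    for kn1, knn, dim in zip(tp1, tpn, SPLIT_DIMS):
        expected = list(kn1)
        expected[dim] //= tp
        assert knn == expected, (kn1, knn)


check("meta-llama/Llama-2-7b-hf", 2)
check("mistralai/Mistral-7B-v0.1", 4)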
@@ -4,7 +4,7 @@ PORT=8000
MODEL=$1
TOKENS=$2

docker run --gpus all --shm-size 1g -p $PORT:80 \
docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
           -v $PWD/data:/data \
           ghcr.io/huggingface/text-generation-inference:1.4.0 \
           --model-id $MODEL \

benchmarks/overheads/benchmark_hashing.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
               ] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)


def main(args):
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for i in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for i in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # analyze the runtime of hashing function
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
            total_calls = stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    print(f"Hashing took {total_time:.2f} seconds,"
          f"{percentage:.2f}% of the total runtime.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance of hashing function in'
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceMangerV2')
    args = parser.parse_args()
    main(args)
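The profiling block above indexes raw pstats tuples (stats.stats[func][3] is cumulative time, [0] is the primitive call count); a minimal illustrative helper (not part of the diff, the function name is hypothetical) that does the same lookup with named fields:

import pstats


def hash_time_fraction(profiler, func_substring="hash_of_block"):
    """Return (cumulative seconds, call count, % of total) for a profiled function."""
    stats = pstats.Stats(profiler)
    cum_time, calls = 0.0, 0
    # pstats values are (call_count, ncalls, tottime, cumtime, callers).
    for (filename, lineno, funcname), (cc, nc, tt, ct, callers) in stats.stats.items():
        if func_substring in funcname:
            cum_time, calls = ct, cc
    return cum_time, calls, 100.0 * cum_time / stats.total_tt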
@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
      "Failed to determine torch nvcc compiler flags")

    if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
      list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
      list(APPEND GPU_FLAGS "-DENABLE_FP8")
    endif()
    if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
      list(REMOVE_ITEM GPU_FLAGS

@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)

    list(APPEND GPU_FLAGS
      "-DUSE_ROCM"
      "-DENABLE_FP8_E4M3"
      "-DENABLE_FP8"
      "-U__HIP_NO_HALF_CONVERSIONS__"
      "-U__HIP_NO_HALF_OPERATORS__"
      "-fno-gpu-rdc")

@ -10,11 +10,11 @@
|
||||
namespace vllm {
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||
// x * sigmoid(x)
|
||||
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
||||
return (T)(((float)x) / (1.0f + expf((float)-x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T gelu_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'none' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
|
||||
const float f = (float) x;
|
||||
const float f = (float)x;
|
||||
constexpr float ALPHA = M_SQRT1_2;
|
||||
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
|
||||
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'tanh' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
|
||||
const float f = (float) x;
|
||||
const float f = (float)x;
|
||||
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
|
||||
constexpr float KAPPA = 0.044715;
|
||||
float x_cube = f * f * f;
|
||||
float inner = BETA * (f + KAPPA * x_cube);
|
||||
return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
|
||||
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
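For reference, the 'tanh' GELU approximation implemented by gelu_tanh_kernel above uses BETA = sqrt(2/pi) and KAPPA = 0.044715, which matches PyTorch's approximate GELU; a minimal sketch (not part of the diff) checking the formula on the host side with PyTorch:

import math

import torch


def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # Same constants as the CUDA kernel: BETA = sqrt(2/pi), KAPPA = 0.044715.
    beta = math.sqrt(2.0 / math.pi)
    kappa = 0.044715
    inner = beta * (x + kappa * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))


x = torch.randn(1024, dtype=torch.float32)
assert torch.allclose(gelu_tanh_reference(x),
                      torch.nn.functional.gelu(x, approximate="tanh"),
                      atol=1e-5)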
// Launch activation and gating kernel.
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"act_and_mul_kernel", \
|
||||
[&] { \
|
||||
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
d); \
|
||||
});
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "act_and_mul_kernel", [&] { \
|
||||
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d); \
|
||||
});
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
void silu_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||
}
|
||||
|
||||
void gelu_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
void gelu_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
|
||||
}
|
||||
@ -96,11 +90,11 @@ void gelu_tanh_and_mul(
|
||||
namespace vllm {
|
||||
|
||||
// Element-wise activation kernel template.
|
||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
__global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
@ -108,54 +102,49 @@ __global__ void activation_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
// Launch element-wise activation kernel.
|
||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||
int d = input.size(-1); \
|
||||
int64_t num_tokens = input.numel() / d; \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"activation_kernel", \
|
||||
[&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
d); \
|
||||
});
|
||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||
int d = input.size(-1); \
|
||||
int64_t num_tokens = input.numel() / d; \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d); \
|
||||
});
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
||||
const float x3 = (float) (x * x * x);
|
||||
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
const float x3 = (float)(x * x * x);
|
||||
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
|
||||
return ((T)0.5) * x * (((T)1.0) + t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
||||
const float f = (float) x;
|
||||
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
const float f = (float)x;
|
||||
const T t =
|
||||
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
|
||||
return ((T)0.5) * x * (((T)1.0) + t);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
void gelu_new(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||
}
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
void gelu_fast(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -22,31 +23,31 @@
|
||||
namespace vllm {
|
||||
|
||||
// A vector type to store Q, K, V elements.
|
||||
template<typename T, int VEC_SIZE>
|
||||
template <typename T, int VEC_SIZE>
|
||||
struct Vec {};
|
||||
|
||||
// A vector type to store FP32 accumulators.
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
struct FloatVec {};
|
||||
|
||||
// Template vector operations.
|
||||
template<typename Acc, typename A, typename B>
|
||||
template <typename Acc, typename A, typename B>
|
||||
inline __device__ Acc mul(A a, B b);
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
inline __device__ float sum(T v);
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
inline __device__ float dot(T a, T b) {
|
||||
return sum(mul<T, T, T>(a, b));
|
||||
}
|
||||
|
||||
template<typename A, typename T>
|
||||
template <typename A, typename T>
|
||||
inline __device__ float dot(T a, T b) {
|
||||
return sum(mul<A, T, T>(a, b));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
inline __device__ void zero(T& dst) {
|
||||
constexpr int WORDS = sizeof(T) / 4;
|
||||
union {
|
||||
@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) {
|
||||
dst = tmp.raw;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
File diff suppressed because it is too large.
@ -1,5 +1,6 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -26,7 +27,7 @@
|
||||
namespace vllm {
|
||||
|
||||
// Q*K^T operation.
|
||||
template<int THREAD_GROUP_SIZE, typename Vec, int N>
|
||||
template <int THREAD_GROUP_SIZE, typename Vec, int N>
|
||||
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
using A_vec = typename FloatVec<Vec>::Type;
|
||||
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
||||
@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
return qk;
|
||||
}
|
||||
|
||||
template<typename T, int THREAD_GROUP_SIZE>
|
||||
template <typename T, int THREAD_GROUP_SIZE>
|
||||
struct Qk_dot {
|
||||
template<typename Vec, int N>
|
||||
template <typename Vec, int N>
|
||||
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -28,8 +30,8 @@
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
@ -50,37 +52,37 @@ struct bf16_8_t {
|
||||
};
|
||||
|
||||
// BF16 vector types for Q, K, V.
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<__nv_bfloat16, 1> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<__nv_bfloat16, 2> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<__nv_bfloat16, 4> {
|
||||
using Type = bf16_4_t;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<__nv_bfloat16, 8> {
|
||||
using Type = bf16_8_t;
|
||||
};
|
||||
|
||||
// FP32 accumulator vector types corresponding to Vec.
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<__nv_bfloat16> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<__nv_bfloat162> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<bf16_4_t> {
|
||||
using Type = Float4_;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<bf16_8_t> {
|
||||
using Type = Float8_;
|
||||
};
|
||||
@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
assert(false);
|
||||
#else
|
||||
#ifndef USE_ROCM
|
||||
return a + b;
|
||||
return a + b;
|
||||
#else
|
||||
return __hadd(a, b);
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
|
||||
}
|
||||
|
||||
// Vector multiplication.
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
|
||||
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
|
||||
bf16_4_t c;
|
||||
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
||||
@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
|
||||
__nv_bfloat162 s = bf162bf162(a);
|
||||
bf16_4_t c;
|
||||
@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
|
||||
bf16_8_t c;
|
||||
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
||||
@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
|
||||
__nv_bfloat162 s = bf162bf162(a);
|
||||
bf16_8_t c;
|
||||
@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
float fa = __bfloat162float(a);
|
||||
float fb = __bfloat162float(b);
|
||||
return fa * fb;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
float2 fa = bf1622float2(a);
|
||||
float2 fb = bf1622float2(b);
|
||||
return mul<float2, float2, float2>(fa, fb);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
|
||||
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
|
||||
Float4_ fc;
|
||||
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
||||
@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
|
||||
__nv_bfloat162 s = bf162bf162(a);
|
||||
Float4_ fc;
|
||||
@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
|
||||
Float8_ fc;
|
||||
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
||||
@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
|
||||
__nv_bfloat162 s = bf162bf162(a);
|
||||
Float8_ fc;
|
||||
@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
|
||||
}
|
||||
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||
__nv_bfloat162 c) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
||||
__nv_bfloat162 c) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
|
||||
}
|
||||
|
||||
// Vector sum.
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(__nv_bfloat16 v) {
|
||||
return __bfloat162float(v);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(__nv_bfloat162 v) {
|
||||
float2 vf = bf1622float2(v);
|
||||
return vf.x + vf.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(bf16_4_t v) {
|
||||
return sum(v.x) + sum(v.y);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(bf16_8_t v) {
|
||||
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
|
||||
}
|
||||
@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -30,37 +32,37 @@
|
||||
namespace vllm {
|
||||
|
||||
// FP16 vector types for Q, K, V.
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<uint16_t, 1> {
|
||||
using Type = uint16_t;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<uint16_t, 2> {
|
||||
using Type = uint32_t;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<uint16_t, 4> {
|
||||
using Type = uint2;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<uint16_t, 8> {
|
||||
using Type = uint4;
|
||||
};
|
||||
|
||||
// FP32 accumulator vector types corresponding to Vec.
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<uint16_t> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<uint32_t> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<uint2> {
|
||||
using Type = Float4_;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct FloatVec<uint4> {
|
||||
using Type = Float8_;
|
||||
};
|
||||
@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
|
||||
return b;
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u16[0] = a;
|
||||
tmp.u16[1] = a;
|
||||
@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
|
||||
: "=r"(tmp.u32)
|
||||
: "f"(f.y), "f"(f.x));
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
#endif
|
||||
#else
|
||||
tmp.u16[0] = float_to_half(f.x);
|
||||
@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
|
||||
}
|
||||
|
||||
// Vector multiplication.
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
|
||||
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ uint2 mul(uint2 a, uint2 b) {
|
||||
uint2 c;
|
||||
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
|
||||
@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ uint2 mul(uint16_t a, uint2 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
uint2 c;
|
||||
@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ uint4 mul(uint4 a, uint4 b) {
|
||||
uint4 c;
|
||||
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
|
||||
@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ uint4 mul(uint16_t a, uint4 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
uint4 c;
|
||||
@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float mul(uint16_t a, uint16_t b) {
|
||||
float fa = half_to_float(a);
|
||||
float fb = half_to_float(b);
|
||||
return fa * fb;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float2 mul(uint32_t a, uint32_t b) {
|
||||
float2 fa = half2_to_float2(a);
|
||||
float2 fb = half2_to_float2(b);
|
||||
return mul<float2, float2, float2>(fa, fb);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float2 mul(uint16_t a, uint32_t b) {
|
||||
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float4_ mul(uint2 a, uint2 b) {
|
||||
Float4_ fc;
|
||||
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
|
||||
@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
Float4_ fc;
|
||||
@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float8_ mul(uint4 a, uint4 b) {
|
||||
Float8_ fc;
|
||||
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
|
||||
@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
Float8_ fc;
|
||||
@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||
uint32_t d;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(d)
|
||||
: "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
|
||||
: "=v"(d)
|
||||
: "v"(a), "v"(b), "v"(c));
|
||||
#endif
|
||||
return d;
|
||||
}
|
||||
@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
|
||||
}
|
||||
|
||||
// Vector sum.
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(uint16_t v) {
|
||||
return half_to_float(v);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(uint32_t v) {
|
||||
float2 tmp = half2_to_float2(v);
|
||||
return tmp.x + tmp.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(uint2 v) {
|
||||
uint32_t c = add(v.x, v.y);
|
||||
return sum(c);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline __device__ float sum(uint4 v) {
|
||||
uint32_t c = add(v.x, v.y);
|
||||
c = add(c, v.z);
|
||||
@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
|
||||
}
|
||||
|
||||
// From float16 to float32.
|
||||
inline __device__ float to_float(uint16_t u) {
|
||||
return half_to_float(u);
|
||||
}
|
||||
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
|
||||
|
||||
inline __device__ float2 to_float(uint32_t u) {
|
||||
return half2_to_float2(u);
|
||||
}
|
||||
inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
|
||||
|
||||
inline __device__ Float4_ to_float(uint2 u) {
|
||||
Float4_ tmp;
|
||||
@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -38,37 +40,35 @@ struct Float8_ {
|
||||
};
|
||||
|
||||
// FP32 vector types for Q, K, V.
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<float, 1> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<float, 2> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
template <>
|
||||
struct Vec<float, 4> {
|
||||
using Type = float4;
|
||||
};
|
||||
|
||||
// FP32 accumulator vector types corresponding to Vec.
|
||||
template<>
|
||||
template <>
struct FloatVec<float> {
using Type = float;
};
template<>
template <>
struct FloatVec<float2> {
using Type = float2;
};
template<>
template <>
struct FloatVec<float4> {
using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) {
return a + b;
}
inline __device__ float add(float a, float b) { return a + b; }

inline __device__ float2 add(float2 a, float2 b) {
float2 c;
@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
}

// Vector multiplication.
template<>
template <>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}

template<>
template <>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c;
}

template<>
template <>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c;
}

template<>
template <>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c;
}

template<>
template <>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
return a * b + c;
}
inline __device__ float fma(float a, float b, float c) { return a * b + c; }

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
}

// Vector sum.
template<>
template <>
inline __device__ float sum(float v) {
return v;
}

template<>
template <>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}

template<>
template <>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}

template<>
template <>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}

template<>
template <>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) {
return a * b;
}
inline __device__ float dot(float a, float b) { return a * b; }

inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
}

// From float to float.
inline __device__ void from_float(float& dst, float src) {
dst = src;
}
inline __device__ void from_float(float& dst, float src) { dst = src; }

inline __device__ void from_float(float2& dst, float2 src) {
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

inline __device__ void from_float(float4& dst, float4 src) {
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { dst = src; }

// From float to float.
inline __device__ float to_float(float u) {
return u;
}
inline __device__ float to_float(float u) { return u; }

inline __device__ float2 to_float(float2 u) {
return u;
}
inline __device__ float2 to_float(float2 u) { return u; }

inline __device__ float4 to_float(float4 u) {
return u;
}
inline __device__ float4 to_float(float4 u) { return u; }

inline __device__ Float4_ to_float(Float4_ u) {
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { return u; }

inline __device__ Float8_ to_float(Float8_ u) {
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { return u; }

// Zero-out a variable.
inline __device__ void zero(float& dst) {
dst = 0.f;
}
inline __device__ void zero(float& dst) { dst = 0.f; }

} // namespace vllm
} // namespace vllm
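Editor's note, illustrative only (not part of the diff): the hunks above are a pure clang-format pass over the float vector helpers. As a rough sketch of how these overloads are composed inside the attention kernels, assuming the `fma` and `sum` overloads declared above and the primary `sum` template from attention_generic.cuh:

  // Sketch: element-wise fused multiply-add over float4 lanes, then a
  // horizontal reduction. The helper name is hypothetical.
  inline __device__ float example_qk_accumulate(float4 q, float4 k, float4 acc) {
    float4 prod = fma(q, k, acc);  // per-lane q * k + acc
    return sum(prod);              // x + y + z + w
  }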

@@ -3,33 +3,39 @@
#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8

namespace vllm {
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)

enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache

template<>
template <>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
using Type = uint8_t;
};

template<>
template <>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
using Type = uint16_t;
};

template<>
template <>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
using Type = uint32_t;
};

template<>
template <>
struct Vec<uint8_t, 8> {
using Type = uint2;
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2

} // namespace vllm
} // namespace vllm
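Editor's note, for context (not part of the diff): the `Vec<uint8_t, N>` specializations above let N packed fp8 bytes travel through the kv cache as a single integer-sized load/store. A minimal compile-time check of that mapping, assuming it is compiled inside namespace vllm with attention_generic.cuh included; the alias name is hypothetical:

  template <int N>
  using PackedFp8Sketch = typename Vec<uint8_t, N>::Type;  // hypothetical alias

  static_assert(sizeof(PackedFp8Sketch<1>) == 1, "one fp8 byte");
  static_assert(sizeof(PackedFp8Sketch<2>) == 2, "two fp8 bytes in a uint16_t");
  static_assert(sizeof(PackedFp8Sketch<4>) == 4, "four fp8 bytes in a uint32_t");
  static_assert(sizeof(PackedFp8Sketch<8>) == 8, "eight fp8 bytes in a uint2");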

csrc/cache.h
@@ -5,34 +5,24 @@
#include <map>
#include <vector>

void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);

void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping);

void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const float kv_scale);

void reshape_and_cache_flash(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);

// Just for unittest
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache);
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float scale, const std::string& kv_cache_dtype);
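Editor's note (not part of the diff): the header change above replaces the `std::map`-based block mappings with dense `torch::Tensor` arguments. A hypothetical caller-side sketch of how a single swap pair is now expressed; the helper name and the single-pair shape are illustrative:

  #include <torch/torch.h>

  // Pack one (src, dst) pair into the (num_pairs, 2) int64 layout that the
  // new swap_blocks signature expects.
  void swap_one_block_sketch(torch::Tensor& src, torch::Tensor& dst,
                             int64_t src_block, int64_t dst_block) {
    // The mapping stays on the CPU; swap_blocks reads it on the host.
    torch::Tensor block_mapping =
        torch::tensor({src_block, dst_block}, torch::kInt64).reshape({1, 2});
    swap_blocks(src, dst, block_mapping);
  }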

@@ -4,10 +4,11 @@

#include "cuda_compat.h"
#include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"

#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <algorithm>
@@ -17,20 +18,17 @@

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost;
@@ -40,41 +38,44 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination");
}

char *src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr());
// NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
const int64_t num_blocks = block_mapping.size(0);
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(
dst_ptr + dst_offset,
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
block_size_in_bytes, memcpy_type, stream);
}
}
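Editor's note (not part of the diff): the NOTE above is the key constraint of the new implementation. `Tensor::item()` on a CUDA tensor forces a device synchronization per element, which is why `block_mapping` must be a CPU tensor. A small illustrative sketch of the equivalent read through a CPU accessor, which also avoids the per-element dispatch cost of `item()`; the function name is hypothetical:

  #include <torch/torch.h>

  void read_block_mapping_sketch(const torch::Tensor& block_mapping) {
    // Valid only for a CPU tensor of shape (num_pairs, 2); mirrors the loop
    // in swap_blocks above without going through Tensor::item().
    auto map = block_mapping.accessor<int64_t, 2>();
    for (int64_t i = 0; i < map.size(0); ++i) {
      int64_t src_block_number = map[i][0];
      int64_t dst_block_number = map[i][1];
      (void)src_block_number;
      (void)dst_block_number;
    }
  }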

namespace vllm {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping,
const int numel_per_block) {
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping,
const int numel_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;

scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

@@ -92,12 +93,11 @@ __global__ void copy_blocks_kernel(
}
}

} // namespace vllm
} // namespace vllm

void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
@@ -111,29 +111,23 @@ void copy_blocks(
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
key_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
// Create block mapping array.
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;

// block_mapping is a 2D tensor with shape (num_pairs, 2).
int num_pairs = block_mapping.size(0);

// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
torch::Tensor key_cache_ptrs_tensor =
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);

// Launch the kernel.
const int numel_per_block = key_caches[0][0].numel();
@@ -142,31 +136,28 @@ void copy_blocks(
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(),
numel_per_block);
}));
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), numel_per_block);
}));
}
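Editor's note (not part of the diff): unlike swap_blocks, the rewritten copy_blocks dereferences `block_mapping` inside a CUDA kernel, so the caller is now responsible for providing the (num_pairs, 2) tensor on the cache device. A hypothetical sketch of flattening the old map-of-vectors mapping into that layout; the helper name is illustrative:

  #include <torch/torch.h>

  #include <utility>
  #include <vector>

  // Flatten {src -> [dst0, dst1, ...]} pairs into the (num_pairs, 2) int64
  // tensor the new copy_blocks expects, then move it to the cache device.
  torch::Tensor make_copy_mapping_sketch(
      const std::vector<std::pair<int64_t, std::vector<int64_t>>>& mapping,
      const torch::Device& cache_device) {
    std::vector<int64_t> flat;
    for (const auto& entry : mapping) {
      for (int64_t dst : entry.second) {
        flat.push_back(entry.first);
        flat.push_back(dst);
      }
    }
    const int64_t num_pairs = static_cast<int64_t>(flat.size()) / 2;
    return torch::tensor(flat, torch::kInt64)
        .reshape({num_pairs, 2})
        .to(cache_device);
  }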

namespace vllm {

template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x,
const float kv_scale) {
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
// block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x,
const float kv_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
@@ -187,47 +178,39 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;

const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
const int64_t tgt_key_idx =
block_idx * num_heads * (head_size / x) * block_size * x +
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
block_offset * x + x_offset;
const int64_t tgt_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
} else {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}
}
}
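Editor's note (not part of the diff): the kernel above now routes the fp8 path through `fp8::scaled_convert<cache_t, scalar_t, kv_dt>`, selected at compile time by the `Fp8KVCacheDataType` template parameter instead of preprocessor branches. The real implementations live in the per-vendor quant_utils.cuh headers included earlier; the following is only a rough sketch of the e4m3 store direction (divide by kv_scale, narrow to fp8, store the byte), assuming <cuda_fp8.h> is available:

  #include <cuda_fp8.h>
  #include <stdint.h>

  // Rough sketch only; not the vLLM implementation. Stores one float into a
  // uint8_t cache slot as scaled fp8 e4m3.
  __device__ inline uint8_t scaled_float_to_fp8_e4m3_sketch(float v,
                                                            float kv_scale) {
    // Quantize with the kv scale, then narrow to fp8 e4m3 with saturation.
    __nv_fp8_storage_t b =
        __nv_cvt_float_to_fp8(v / kv_scale, __NV_SATFINITE, __NV_E4M3);
    return static_cast<uint8_t>(b);
  }

The load direction in convert_fp8 further down is the mirror image: widen the fp8 byte back to float and multiply by kv_scale.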
|
||||
|
||||
template<typename scalar_t>
|
||||
template <typename scalar_t>
|
||||
__global__ void reshape_and_cache_flash_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride,
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size) {
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -242,40 +225,37 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride
|
||||
+ block_offset * num_heads * head_size
|
||||
+ head_idx * head_size
|
||||
+ head_offset;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
k_cache[tgt_value_idx] = key[src_key_idx];
|
||||
v_cache[tgt_value_idx] = value[src_value_idx];
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x, \
|
||||
kv_scale);
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||
num_heads, head_size, block_size, x, kv_scale);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale)
|
||||
{
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor&
|
||||
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype, const float kv_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
@ -289,35 +269,18 @@ void reshape_and_cache(
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, float, false);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
}
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype) {
|
||||
// FIXME: only support auto datatype, does not support fp8
|
||||
if (kv_cache_dtype != "auto") {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
@ -337,63 +300,47 @@ void reshape_and_cache_flash(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_flash",
|
||||
[&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(),
|
||||
v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(),
|
||||
block_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size);
|
||||
});
|
||||
key.scalar_type(), "reshape_and_cache_flash", [&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
|
||||
value_stride, num_heads, head_size, block_size);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__global__ void convert_fp8_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const int64_t block_stride) {
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const float kv_scale,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
dst_cache[idx] =
|
||||
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_CONVERT_FP8(Tout, Tin) \
|
||||
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
block_stride);
|
||||
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
||||
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
|
||||
|
||||
void convert_fp8(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache)
|
||||
{
|
||||
// Only for testing.
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const float kv_scale, const std::string& kv_cache_dtype) {
|
||||
torch::Device src_device = src_cache.device();
|
||||
torch::Device dst_device = dst_cache.device();
|
||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
TORCH_CHECK(src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
||||
|
||||
int64_t num_blocks = src_cache.size(0);
|
||||
@ -403,17 +350,37 @@ void convert_fp8(
|
||||
dim3 block(std::min(block_stride, int64_t(512)));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(uint8_t, float);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint8_t, uint16_t);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(float, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint16_t, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
|
||||
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
|
||||
bool is_gated>
|
||||
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
|
||||
scalar_t *__restrict__ output) {
|
||||
void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
|
||||
scalar_t* __restrict__ output) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
|
||||
@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
|
||||
}
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
|
||||
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
|
||||
const vec_op::FP32Vec8 zeros(0.0);
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
return x / (ones + (zeros - x).exp());
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(0.79788456f);
|
||||
const vec_op::FP32Vec8 w2(0.044715f);
|
||||
@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
|
||||
return w3 * x * (ones + t);
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(0.79788456f);
|
||||
const vec_op::FP32Vec8 w2(0.044715f);
|
||||
@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
|
||||
return w3 * x * (ones + t);
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(M_SQRT1_2);
|
||||
const vec_op::FP32Vec8 w2(0.5);
|
||||
return x * w2 * (ones + (x * w1).er());
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
|
||||
const vec_op::FP32Vec8 w2(0.5);
|
||||
@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
|
||||
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
|
||||
return x * w2 * (ones + inner.tanh());
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
|
||||
activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
|
||||
input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
|
||||
});
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
|
||||
activation_kernel<scalar_t, silu_act, true>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_and_mul(torch::Tensor &out, // [..., d]
|
||||
torch::Tensor &input) // [..., 2 * d]
|
||||
void gelu_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "gelu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
|
||||
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
|
||||
input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
|
||||
});
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
|
||||
activation_kernel<scalar_t, gelu_act, true>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
|
||||
torch::Tensor &input) // [..., 2 * d]
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_new(torch::Tensor &out, torch::Tensor &input) {
|
||||
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1);
|
||||
|
||||
@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) {
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
|
||||
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1);
|
||||
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t> struct KernelVecType {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using q_load_vec_type = void;
|
||||
using q_vec_type = void;
|
||||
using k_load_vec_type = void;
|
||||
@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
|
||||
using v_load_vec_type = void;
|
||||
};
|
||||
|
||||
template <> struct KernelVecType<float> {
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using q_load_vec_type = vec_op::FP32Vec4;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::FP32Vec16;
|
||||
@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
|
||||
};
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <> struct KernelVecType<c10::BFloat16> {
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::BF16Vec32;
|
||||
using k_load_vec_type = vec_op::BF16Vec32;
|
||||
@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#else
|
||||
template <> struct KernelVecType<c10::BFloat16> {
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::BF16Vec16;
|
||||
@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> {
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
|
||||
const int capacity) {
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T>
|
||||
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
|
||||
const float alibi_slope, const int start_index,
|
||||
const int seq_len) {
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
|
||||
const int capacity,
|
||||
const float alibi_slope,
|
||||
const int start_index,
|
||||
const int seq_len) {
|
||||
data[0] += alibi_slope * (start_index - seq_len + 1);
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
|
||||
FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
|
||||
const int size) {
|
||||
T max = max_data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
@ -132,9 +137,9 @@ struct reduceQKBlockKernel {
|
||||
static_assert(k_load_vec_type::get_elem_num() % x == 0);
|
||||
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
|
||||
|
||||
FORCE_INLINE static void call(const scalar_t *__restrict__ q,
|
||||
const scalar_t *__restrict__ k_block,
|
||||
float *__restrict__ logits, float scale,
|
||||
FORCE_INLINE static void call(const scalar_t* __restrict__ q,
|
||||
const scalar_t* __restrict__ k_block,
|
||||
float* __restrict__ logits, float scale,
|
||||
const int token_num) {
|
||||
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
|
||||
|
||||
@ -196,8 +201,8 @@ struct reduceQKBlockKernel {
|
||||
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
|
||||
int HEAD_PARTITION_SIZE, typename acc_t>
|
||||
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
|
||||
acc_t &&acc) {
|
||||
FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
|
||||
acc_t&& acc) {
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
|
||||
static_assert(BLOCK_SIZE == ELEM_NUM);
|
||||
@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
|
||||
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
|
||||
});
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
// Paged attention v1
|
||||
namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
|
||||
struct paged_attention_v1_impl {
|
||||
static void
|
||||
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
static void call(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int
|
||||
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int *__restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float *__restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads) {
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs,
|
||||
// max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads) {
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
|
||||
@ -243,32 +248,31 @@ struct paged_attention_v1_impl {
|
||||
|
||||
size_t logits_bytes =
|
||||
parallel_work_item_num * max_seq_len_padded * sizeof(float);
|
||||
float *logits = (float *)std::aligned_alloc(
|
||||
64, logits_bytes); // Cacheline alignment for each context token.
|
||||
// [parallel_work_item_num, max_seq_len_padded]
|
||||
float* logits = (float*)std::aligned_alloc(
|
||||
64, logits_bytes); // Cacheline alignment for each context token.
|
||||
// [parallel_work_item_num, max_seq_len_padded]
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
int seq_len = seq_lens[seq_idx];
|
||||
const int *seq_block_table =
|
||||
const int* seq_block_table =
|
||||
block_tables + max_num_blocks_per_seq * seq_idx;
|
||||
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t *__restrict__ q_vec_ptr =
|
||||
const scalar_t* __restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
const int last_block_token_num =
|
||||
seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
float *__restrict__ thread_block_logits =
|
||||
const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
float* __restrict__ thread_block_logits =
|
||||
logits + omp_get_thread_num() * max_seq_len_padded;
|
||||
|
||||
// Compute logits
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const scalar_t *__restrict__ k_block_cache_ptr =
|
||||
const scalar_t* __restrict__ k_block_cache_ptr =
|
||||
k_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
float *__restrict__ head_block_logits =
|
||||
float* __restrict__ head_block_logits =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
|
||||
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||
@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
|
||||
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
|
||||
seq_len);
|
||||
} else {
|
||||
reduceSoftmax(thread_block_logits, seq_len,
|
||||
block_num * BLOCK_SIZE);
|
||||
reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute value
|
||||
@ -293,14 +296,14 @@ struct paged_attention_v1_impl {
|
||||
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||
++head_part_idx) {
|
||||
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||
scalar_t *__restrict__ out_ptr =
|
||||
scalar_t* __restrict__ out_ptr =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||
head_part_idx * head_elem_num_per_partition;
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const float *__restrict__ prob_vec_ptr =
|
||||
const float* __restrict__ prob_vec_ptr =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
const scalar_t *__restrict__ v_block_cache_ptr =
|
||||
const scalar_t* __restrict__ v_block_cache_ptr =
|
||||
v_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
@ -311,7 +314,7 @@ struct paged_attention_v1_impl {
|
||||
if (block_idx != block_num - 1) {
|
||||
const int64_t next_physical_block_idx =
|
||||
seq_block_table[block_idx + 1];
|
||||
const scalar_t *__restrict__ next_v_block_cache_ptr =
|
||||
const scalar_t* __restrict__ next_v_block_cache_ptr =
|
||||
v_cache + next_physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
@ -340,16 +343,16 @@ struct paged_attention_v1_impl {
|
||||
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
||||
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
|
||||
num_heads);
|
||||
|
||||
template <typename T, int BLOCK_SIZE>
|
||||
void paged_attention_v1_impl_launcher(
|
||||
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
|
||||
torch::Tensor &value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor &block_tables, torch::Tensor &seq_lens,
|
||||
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -359,68 +362,73 @@ void paged_attention_v1_impl_launcher(
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float *alibi_slopes_ptr =
|
||||
const float* alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
|
||||
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
|
||||
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
|
||||
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
|
||||
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
|
||||
int *block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int *seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 64:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 192:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||
seq_lens, max_seq_len, alibi_slopes);
|
||||
|
||||
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V1_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V1_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
||||
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
|
||||
torch::Tensor &key_cache, torch::Tensor &value_cache,
|
||||
int num_kv_heads, float scale,
|
||||
torch::Tensor &block_tables,
|
||||
torch::Tensor &seq_lens, int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
|
||||
@ -434,23 +442,24 @@ namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
|
||||
struct paged_attention_v2_impl {
|
||||
static void call(
|
||||
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float
|
||||
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads,
|
||||
// max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int
|
||||
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int *__restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ block_tables, // [num_seqs,
|
||||
// max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float *__restrict__ alibi_slopes, // [num_heads]
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads, const int max_num_partitions) {
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
@ -468,8 +477,7 @@ struct paged_attention_v2_impl {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int start_token_idx = partition_idx * PARTITION_SIZE;
|
||||
|
||||
if (start_token_idx >= seq_len)
|
||||
continue;
|
||||
if (start_token_idx >= seq_len) continue;
|
||||
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
@ -477,15 +485,14 @@ struct paged_attention_v2_impl {
|
||||
const int token_num =
|
||||
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
|
||||
start_token_idx);
|
||||
const int block_num =
|
||||
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int last_block_token_num =
|
||||
token_num - (block_num - 1) * BLOCK_SIZE;
|
||||
const int *seq_block_table = block_tables +
|
||||
const int* seq_block_table = block_tables +
|
||||
max_num_blocks_per_seq * seq_idx +
|
||||
start_token_idx / BLOCK_SIZE;
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t *__restrict__ q_vec_ptr =
|
||||
const scalar_t* __restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
|
||||
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
|
||||
@ -493,10 +500,10 @@ struct paged_attention_v2_impl {
|
||||
// Compute logits
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const scalar_t *__restrict__ k_block_cache_ptr =
|
||||
const scalar_t* __restrict__ k_block_cache_ptr =
|
||||
k_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
float *__restrict__ head_block_logits =
|
||||
float* __restrict__ head_block_logits =
|
||||
logits + block_idx * BLOCK_SIZE;
|
||||
|
||||
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||
@ -510,13 +517,13 @@ struct paged_attention_v2_impl {
|
||||
logits, token_num, block_num * BLOCK_SIZE,
|
||||
alibi_slopes[head_idx], start_token_idx, seq_len);
|
||||
} else {
|
||||
max_and_sum = reduceSoftmax(logits, token_num,
|
||||
block_num * BLOCK_SIZE);
|
||||
max_and_sum =
|
||||
reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
auto &&[max_logit, exp_sum] = max_and_sum;
|
||||
auto&& [max_logit, exp_sum] = max_and_sum;
|
||||
|
||||
scalar_t *__restrict__ output_buffer = nullptr;
|
||||
scalar_t* __restrict__ output_buffer = nullptr;
|
||||
if (!no_reduce) {
|
||||
auto idx = seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions + partition_idx;
|
||||
@ -538,13 +545,13 @@ struct paged_attention_v2_impl {
|
||||
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||
++head_part_idx) {
|
||||
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||
scalar_t *__restrict__ out_ptr =
|
||||
scalar_t* __restrict__ out_ptr =
|
||||
output_buffer + head_part_idx * head_elem_num_per_partition;
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const float *__restrict__ prob_vec_ptr =
|
||||
const float* __restrict__ prob_vec_ptr =
|
||||
logits + block_idx * BLOCK_SIZE;
|
||||
const scalar_t *__restrict__ v_block_cache_ptr =
|
||||
const scalar_t* __restrict__ v_block_cache_ptr =
|
||||
v_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
@ -555,7 +562,7 @@ struct paged_attention_v2_impl {
|
||||
if (block_idx != block_num - 1) {
|
||||
const int64_t next_physical_block_idx =
|
||||
seq_block_table[block_idx + 1];
|
||||
const scalar_t *__restrict__ next_v_block_cache_ptr =
|
||||
const scalar_t* __restrict__ next_v_block_cache_ptr =
|
||||
v_cache + next_physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
@ -587,8 +594,7 @@ struct paged_attention_v2_impl {
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1)
|
||||
continue;
|
||||
if (partition_num == 1) continue;
|
||||
|
||||
reducePartitonSoftmax(
|
||||
max_logits + seq_idx * num_heads * max_num_partitions +
|
||||
@ -603,11 +609,11 @@ struct paged_attention_v2_impl {
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
|
||||
constexpr int head_elem_num_per_group =
|
||||
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE
|
||||
// didn't align with 64 bytes
|
||||
16; // Note: didn't align with the cacheline size, due to some
|
||||
// HEAD_SIZE didn't align with 64 bytes
|
||||
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
|
||||
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
|
||||
const float *__restrict__ rescale_factors = exp_sums;
|
||||
const float* __restrict__ rescale_factors = exp_sums;
|
||||
#pragma omp parallel for collapse(3) schedule(static, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
@ -616,17 +622,16 @@ struct paged_attention_v2_impl {
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1)
|
||||
continue;
|
||||
if (partition_num == 1) continue;
|
||||
|
||||
const float *__restrict__ seq_head_rescale_factors =
|
||||
const float* __restrict__ seq_head_rescale_factors =
|
||||
rescale_factors + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions;
|
||||
const scalar_t *__restrict__ seq_head_tmp_out =
|
||||
const scalar_t* __restrict__ seq_head_tmp_out =
|
||||
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE +
|
||||
group_idx * head_elem_num_per_group;
|
||||
scalar_t *__restrict__ seq_head_output =
|
||||
scalar_t* __restrict__ seq_head_output =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||
group_idx * head_elem_num_per_group;
|
||||
|
||||
@@ -645,21 +650,21 @@ struct paged_attention_v2_impl {
}
};

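// The kernel above calls reduceSoftmax(logits, token_num, padded_len) to get
// the running max and exp-sum of one partition's logits. Below is a
// simplified, standalone scalar sketch of that max/exp-sum reduction. It is
// illustrative only: plain C++, not the vectorized vLLM kernel, and the
// function and variable names are made up for the example.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// Normalizes logits[0..len) in place and returns {max_logit, exp_sum}.
// Entries in [len, padded_len) are treated as padding and zeroed out.
std::pair<float, float> reduce_softmax_sketch(float* logits, int len,
                                              int padded_len) {
  float max_logit = *std::max_element(logits, logits + len);
  float exp_sum = 0.f;
  for (int i = 0; i < len; ++i) {
    logits[i] = std::exp(logits[i] - max_logit);  // subtract max for stability
    exp_sum += logits[i];
  }
  for (int i = len; i < padded_len; ++i) logits[i] = 0.f;  // mask padding
  for (int i = 0; i < len; ++i) logits[i] /= exp_sum;
  return {max_logit, exp_sum};
}

int main() {
  std::vector<float> logits = {1.f, 2.f, 3.f, 0.f};  // last slot is padding
  auto [max_logit, exp_sum] = reduce_softmax_sketch(logits.data(), 3, 4);
  std::printf("max=%f sum=%f p0=%f\n", max_logit, exp_sum, logits[0]);
}
// Returning both the max and the exp-sum is what later lets
// reducePartitonSoftmax rescale and combine partial results across partitions.
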
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);

template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -670,73 +675,78 @@ void paged_attention_v2_impl_launcher(
int max_num_partitions = exp_sums.size(-1);

// NOTE: alibi_slopes is optional.
|
||||
const float *alibi_slopes_ptr =
|
||||
const float* alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
|
||||
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
|
||||
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
|
||||
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
|
||||
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
|
||||
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
|
||||
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
|
||||
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
|
||||
int *block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int *seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
|
||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 64:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 192:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}

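// The launcher above turns a runtime head_size into a compile-time template
// parameter via LAUNCH_V2_ATTENTION_KERNEL, and rejects unsupported sizes.
// A miniature, self-contained sketch of that switch-to-template dispatch
// pattern follows; all names here are illustrative, not vLLM code.
#include <cstdio>
#include <stdexcept>

template <int HEAD_SIZE>
void attention_kernel_sketch() {
  // HEAD_SIZE is a compile-time constant here, so inner loops over it can be
  // unrolled and vector widths chosen statically.
  std::printf("running kernel specialized for head_size=%d\n", HEAD_SIZE);
}

#define LAUNCH_SKETCH(HEAD_SIZE) attention_kernel_sketch<HEAD_SIZE>()

void dispatch_by_head_size(int head_size) {
  switch (head_size) {
    case 64:
      LAUNCH_SKETCH(64);
      break;
    case 128:
      LAUNCH_SKETCH(128);
      break;
    default:
      throw std::invalid_argument("Unsupported head size");
  }
}

int main() { dispatch_by_head_size(128); }
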
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, block_size, \
|
||||
max_seq_len, alibi_slopes);
|
||||
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
|
||||
alibi_slopes);
|
||||
|
||||
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V2_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V2_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
} // namespace
}  // namespace

void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &max_logits, torch::Tensor &tmp_out,
torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

@ -5,25 +5,26 @@
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(
|
||||
std::vector<torch::Tensor> &key_caches,
|
||||
std::vector<torch::Tensor> &value_caches,
|
||||
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
|
||||
const int element_num_per_block, const int layer_num) {
|
||||
const size_t pair_num = mapping_pairs.size();
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
const torch::Tensor& mapping_pairs,
|
||||
const int element_num_per_block,
|
||||
const int layer_num) {
|
||||
const size_t pair_num = mapping_pairs.size(0);
|
||||
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int layer = 0; layer < layer_num; ++layer) {
|
||||
for (size_t pair = 0; pair < pair_num; ++pair) {
|
||||
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
|
||||
int64_t source_offset =
|
||||
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
|
||||
int64_t target_offset =
|
||||
element_num_per_block * mapping_pairs[pair].second;
|
||||
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||
scalar_t *source_ptr = key_cache_ptr + source_offset;
|
||||
scalar_t *target_ptr = key_cache_ptr + target_offset;
|
||||
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
|
||||
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||
scalar_t* source_ptr = key_cache_ptr + source_offset;
|
||||
scalar_t* target_ptr = key_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
|
||||
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
|
||||
scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
|
||||
source_ptr = value_cache_ptr + source_offset;
|
||||
target_ptr = value_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
@ -33,9 +34,9 @@ void copy_blocks_cpu_impl(
|
||||
|
||||
template <typename scalar_t>
|
||||
void reshape_and_cache_cpu_impl(
|
||||
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
|
||||
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
|
||||
const int64_t *__restrict__ slot_mapping, const int num_tokens,
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int num_tokens,
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x) {
|
||||
const int block_elem_num = num_heads * head_size * block_size;
|
||||
@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl(
|
||||
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
|
||||
int src_value_head_idx =
|
||||
token_idx * value_stride + head_idx * head_size;
|
||||
const scalar_t *src_key_head_ptr = key + src_key_head_idx;
|
||||
const scalar_t *src_value_head_ptr = value + src_value_head_idx;
|
||||
const scalar_t* src_key_head_ptr = key + src_key_head_idx;
|
||||
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
|
||||
const int64_t block_index = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
scalar_t *target_key_head_ptr = key_cache +
|
||||
scalar_t* target_key_head_ptr = key_cache +
|
||||
block_elem_num * block_index +
|
||||
head_idx * block_size * head_size;
|
||||
scalar_t *target_value_head_ptr = value_cache +
|
||||
scalar_t* target_value_head_ptr = value_cache +
|
||||
block_elem_num * block_index +
|
||||
head_idx * block_size * head_size;
|
||||
|
||||
@ -79,39 +80,31 @@ void reshape_and_cache_cpu_impl(
|
||||
}
|
||||
}
|
||||
}
}; // namespace
};  // namespace

void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
int num_layers = key_caches.size();
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}

std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}

const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}

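// copy_blocks above hands a (src_block, dst_block) mapping to
// copy_blocks_cpu_impl, which memcpy's whole cache blocks per layer. A plain
// C++ sketch of that indexing pattern follows; no torch is used, and the
// containers and names are illustrative only.
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

void copy_blocks_sketch(std::vector<float>& flat_cache,
                        const std::vector<std::pair<int64_t, int64_t>>& mapping,
                        int64_t elems_per_block) {
  const size_t block_bytes = sizeof(float) * elems_per_block;
  for (const auto& [src_block, dst_block] : mapping) {
    // Each block occupies a contiguous range of elems_per_block elements.
    const float* src = flat_cache.data() + src_block * elems_per_block;
    float* dst = flat_cache.data() + dst_block * elems_per_block;
    std::memcpy(dst, src, block_bytes);
  }
}

int main() {
  std::vector<float> cache(4 * 8, 0.f);        // 4 blocks of 8 elements
  for (int i = 0; i < 8; ++i) cache[i] = 1.f;  // fill block 0
  copy_blocks_sketch(cache, {{0, 3}}, 8);      // copy block 0 -> block 3
  return cache[3 * 8] == 1.f ? 0 : 1;
}
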
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
|
||||
torch::Tensor &key_cache, torch::Tensor &value_cache,
|
||||
torch::Tensor &slot_mapping,
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
|
||||
int num_tokens = key.size(0);
|
||||
@ -135,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
|
||||
});
|
||||
}
|
||||
|
||||
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
|
||||
const std::map<int64_t, int64_t> &block_mapping) {
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping) {
|
||||
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
|
||||
}
|
||||
|
||||
@ -2,10 +2,10 @@
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void rms_norm_impl(scalar_t *__restrict__ out,
|
||||
const scalar_t *__restrict__ input,
|
||||
const scalar_t *__restrict__ weight, const float epsilon,
|
||||
const int num_tokens, const int hidden_size) {
|
||||
void rms_norm_impl(scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ weight, const float epsilon,
|
||||
const int num_tokens, const int hidden_size) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
|
||||
@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
|
||||
scalar_t *__restrict__ residual,
|
||||
const scalar_t *__restrict__ weight,
|
||||
const float epsilon, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
|
||||
scalar_t* __restrict__ residual,
|
||||
const scalar_t* __restrict__ weight,
|
||||
const float epsilon, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
|
||||
@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
}  // namespace

void rms_norm(torch::Tensor &out, torch::Tensor &input,
torch::Tensor &weight, float epsilon) {
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl)
});
}

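// rms_norm_impl above applies, per token, out[i] = input[i] * weight[i] /
// sqrt(mean(input^2) + epsilon), using vectorized loads. The following is a
// scalar C++ sketch of the same math, for illustration only; it is not the
// vLLM kernel and the names are made up.
#include <cmath>
#include <cstdio>
#include <vector>

void rms_norm_sketch(float* out, const float* input, const float* weight,
                     float epsilon, int num_tokens, int hidden_size) {
  for (int t = 0; t < num_tokens; ++t) {
    const float* x = input + t * hidden_size;
    float sq_sum = 0.f;
    for (int i = 0; i < hidden_size; ++i) sq_sum += x[i] * x[i];
    const float inv_rms = 1.f / std::sqrt(sq_sum / hidden_size + epsilon);
    for (int i = 0; i < hidden_size; ++i)
      out[t * hidden_size + i] = x[i] * inv_rms * weight[i];
  }
}

int main() {
  std::vector<float> in = {1.f, 2.f, 3.f, 4.f}, w(4, 1.f), out(4);
  rms_norm_sketch(out.data(), in.data(), w.data(), 1e-6f, /*num_tokens=*/1,
                  /*hidden_size=*/4);
  std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
}
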
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
|
||||
torch::Tensor &weight, float epsilon) {
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
|
||||
@ -4,107 +4,107 @@
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_impl(
|
||||
const int64_t
|
||||
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t
|
||||
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
|
||||
/// [num_tokens, num_heads, head_size]
|
||||
scalar_t
|
||||
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
|
||||
// [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t
|
||||
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
|
||||
/// head_size] or [num_tokens, num_heads,
|
||||
/// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size,
|
||||
const int num_tokens) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
constexpr int ELEM_SIZE = sizeof(scalar_t);
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
|
||||
bool flag = (embed_dim % VEC_ELEM_NUM == 0);
|
||||
const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
|
||||
|
||||
auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
|
||||
scalar_t* qk) {
|
||||
int j = 0;
|
||||
for (; j < loop_upper; j += VEC_ELEM_NUM) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int64_t out_x = token_head + x_index;
|
||||
const int64_t out_y = token_head + y_index;
|
||||
|
||||
const scalar_vec_t cos(cache_ptr + x_index);
|
||||
const scalar_vec_t sin(cache_ptr + y_index);
|
||||
|
||||
const scalar_vec_t q_x(qk + out_x);
|
||||
const scalar_vec_t q_y(qk + out_y);
|
||||
|
||||
vec_op::FP32Vec8 fp32_cos(cos);
|
||||
vec_op::FP32Vec8 fp32_sin(sin);
|
||||
|
||||
vec_op::FP32Vec8 fp32_q_x(q_x);
|
||||
vec_op::FP32Vec8 fp32_q_y(q_y);
|
||||
|
||||
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
||||
scalar_vec_t(out1).save(qk + out_x);
|
||||
|
||||
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
||||
scalar_vec_t(out2).save(qk + out_y);
|
||||
}
|
||||
if (!flag) {
|
||||
for (; j < embed_dim; ++j) {
|
||||
const int x_index = j;
|
||||
const int y_index = embed_dim + j;
|
||||
|
||||
const int64_t out_x = token_head + x_index;
|
||||
const int64_t out_y = token_head + y_index;
|
||||
|
||||
const float fp32_cos = cache_ptr[x_index];
|
||||
const float fp32_sin = cache_ptr[y_index];
|
||||
|
||||
const float fp32_q_x = qk[out_x];
|
||||
const float fp32_q_y = qk[out_y];
|
||||
|
||||
qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
||||
qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
for (int i = 0; i < num_heads; ++i) {
|
||||
const int head_idx = i;
|
||||
const int64_t token_head =
|
||||
token_idx * query_stride + head_idx * head_size;
|
||||
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int64_t out_x = token_head + x_index;
|
||||
const int64_t out_y = token_head + y_index;
|
||||
|
||||
const scalar_vec_t cos(cache_ptr + x_index);
|
||||
const scalar_vec_t sin(cache_ptr + y_index);
|
||||
|
||||
const scalar_vec_t q_x(query + out_x);
|
||||
const scalar_vec_t q_y(query + out_y);
|
||||
|
||||
vec_op::FP32Vec8 fp32_cos(cos);
|
||||
vec_op::FP32Vec8 fp32_sin(sin);
|
||||
|
||||
vec_op::FP32Vec8 fp32_q_x(q_x);
|
||||
vec_op::FP32Vec8 fp32_q_y(q_y);
|
||||
|
||||
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
||||
scalar_vec_t(out1).save(query + out_x);
|
||||
|
||||
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
||||
scalar_vec_t(out2).save(query + out_y);
|
||||
}
|
||||
compute_loop(token_head, cache_ptr, query);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_kv_heads; ++i) {
|
||||
const int head_idx = i;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int64_t out_x = token_head + x_index;
|
||||
const int64_t out_y = token_head + y_index;
|
||||
|
||||
const scalar_vec_t cos(cache_ptr + x_index);
|
||||
const scalar_vec_t sin(cache_ptr + y_index);
|
||||
|
||||
const scalar_vec_t k_x(key + out_x);
|
||||
const scalar_vec_t k_y(key + out_y);
|
||||
|
||||
vec_op::FP32Vec8 fp32_cos(cos);
|
||||
vec_op::FP32Vec8 fp32_sin(sin);
|
||||
|
||||
vec_op::FP32Vec8 fp32_k_x(k_x);
|
||||
vec_op::FP32Vec8 fp32_k_y(k_y);
|
||||
|
||||
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
|
||||
scalar_vec_t(out1).save(key + out_x);
|
||||
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
|
||||
scalar_vec_t(out2).save(key + out_y);
|
||||
}
compute_loop(token_head, cache_ptr, key);
}
}
}

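// rotary_embedding_impl above rotates each (x, y) pair of a head using cached
// cos/sin values, with the GPT-NeoX layout: x comes from the first half of
// rot_dim and y from the second half. A scalar sketch of that rotation
// follows; it is illustrative C++, not the vectorized vLLM implementation.
#include <cmath>
#include <cstdio>
#include <vector>

void rotate_neox_sketch(float* head, const float* cos_sin, int embed_dim) {
  // cos_sin holds embed_dim cosines followed by embed_dim sines.
  for (int j = 0; j < embed_dim; ++j) {
    const float c = cos_sin[j];
    const float s = cos_sin[embed_dim + j];
    const float x = head[j];
    const float y = head[embed_dim + j];
    head[j] = x * c - y * s;
    head[embed_dim + j] = y * c + x * s;
  }
}

int main() {
  const int embed_dim = 2;
  std::vector<float> head = {1.f, 0.f, 0.f, 1.f};
  // cos/sin values corresponding to a 90-degree rotation of both pairs.
  std::vector<float> cos_sin = {0.f, 0.f, 1.f, 1.f};
  rotate_neox_sketch(head.data(), cos_sin.data(), embed_dim);
  std::printf("%f %f %f %f\n", head[0], head[1], head[2], head[3]);
}
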
template <typename scalar_t>
|
||||
void rotary_embedding_gptj_impl(
|
||||
const int64_t
|
||||
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t
|
||||
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
|
||||
/// [num_tokens, num_heads, head_size]
|
||||
scalar_t
|
||||
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
|
||||
// [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t
|
||||
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
|
||||
/// head_size] or [num_tokens, num_heads,
|
||||
/// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size,
|
||||
const int num_tokens) {
|
||||
@ -114,13 +114,13 @@ void rotary_embedding_gptj_impl(
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (int i = 0; i < num_heads; ++i) {
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const scalar_t *cos_cache_ptr = cache_ptr;
|
||||
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const scalar_t* cos_cache_ptr = cache_ptr;
|
||||
const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||
const int head_idx = i;
|
||||
const int64_t token_head =
|
||||
token_idx * query_stride + head_idx * head_size;
|
||||
scalar_t *head_query = token_head + query;
|
||||
scalar_t* head_query = token_head + query;
|
||||
for (int j = 0; j < embed_dim; j += 1) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = 2 * rot_offset;
|
||||
@ -142,12 +142,12 @@ void rotary_embedding_gptj_impl(
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (int i = 0; i < num_kv_heads; ++i) {
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const scalar_t *cos_cache_ptr = cache_ptr;
|
||||
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const scalar_t* cos_cache_ptr = cache_ptr;
|
||||
const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||
const int head_idx = i;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
scalar_t *head_key = key + token_head;
|
||||
scalar_t* head_key = key + token_head;
|
||||
for (int j = 0; j < embed_dim; j += 1) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = 2 * rot_offset;
|
||||
@ -165,11 +165,11 @@ void rotary_embedding_gptj_impl(
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
|
||||
torch::Tensor &key, int head_size,
|
||||
torch::Tensor &cos_sin_cache, bool is_neox) {
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox) {
|
||||
int num_tokens = query.numel() / query.size(-1);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
|
||||
@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||
|
||||
// Attention ops
|
||||
ops.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
ops.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
ops.def("paged_attention_v1", &paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached "
|
||||
"keys/values using PagedAttention.");
|
||||
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
|
||||
|
||||
// Activation ops
|
||||
ops.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
ops.def(
|
||||
"gelu_and_mul",
|
||||
&gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def(
|
||||
"gelu_tanh_and_mul",
|
||||
&gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
ops.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
|
||||
ops.def("gelu_and_mul", &gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
|
||||
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
|
||||
|
||||
// Layernorm
|
||||
ops.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
ops.def("rms_norm", &rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
ops.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
|
||||
// Rotary embedding
|
||||
ops.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
ops.def("rotary_embedding", &rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
}

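// The bindings above register each C++ op on a pybind11 submodule of the
// extension module. A minimal standalone sketch of that registration pattern
// follows; the module, submodule, and function names here are illustrative,
// not the real vLLM bindings.
#include <pybind11/pybind11.h>

namespace {
int add(int a, int b) { return a + b; }
}  // namespace

PYBIND11_MODULE(sketch_ext, m) {
  // A submodule groups related operators under one Python namespace.
  pybind11::module ops = m.def_submodule("ops", "example operators");
  ops.def("add", &add, "Add two integers.");
}

// Built as a Python extension, this would be used as:
//   from sketch_ext import ops
//   ops.add(1, 2)
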
@@ -1,7 +1,7 @@
#pragma once

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
  #include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
@@ -17,9 +17,14 @@
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
@@ -28,6 +33,13 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
@@ -35,4 +47,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif


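// The VLLM_SHFL_XOR_SYNC wrapper above maps to __shfl_xor_sync on CUDA and
// __shfl_xor on ROCm. A typical use is a butterfly (XOR) reduction across a
// warp. The small CUDA sketch below illustrates that pattern with its own
// local macro; it is a standalone example, not vLLM code.
#include <cstdio>
#include <cuda_runtime.h>

#define SHFL_XOR_SYNC(var, lane_mask) \
  __shfl_xor_sync(0xffffffffu, var, lane_mask)

__global__ void warp_sum_kernel(const float* in, float* out) {
  float val = in[threadIdx.x];
  // After log2(32) = 5 butterfly steps every lane holds the warp-wide sum.
  for (int mask = 16; mask > 0; mask >>= 1) val += SHFL_XOR_SYNC(val, mask);
  if (threadIdx.x == 0) *out = val;
}

int main() {
  float h_in[32], *d_in, *d_out, h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %f\n", h_out);  // expect 32.0
  cudaFree(d_in);
  cudaFree(d_out);
}
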
@ -2,9 +2,6 @@
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
int get_device_attribute(int attribute, int device_id);
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
|
||||
int device_id);
|
||||
int get_max_shared_memory_per_block_device_attribute(int device_id);
|
||||
|
||||
@ -2,34 +2,28 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#endif
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
{
|
||||
int device, value;
|
||||
if (device_id < 0) {
|
||||
cudaGetDevice(&device);
|
||||
}
|
||||
else {
|
||||
device = device_id;
|
||||
}
|
||||
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||
return value;
|
||||
int get_device_attribute(int attribute, int device_id) {
|
||||
int device, value;
|
||||
if (device_id < 0) {
|
||||
cudaGetDevice(&device);
|
||||
} else {
|
||||
device = device_id;
|
||||
}
|
||||
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
|
||||
device);
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
int device_id)
{
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

#ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif

return get_device_attribute(attribute, device_id);
return get_device_attribute(attribute, device_id);
}

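// The helper above ultimately calls cudaDeviceGetAttribute with
// cudaDevAttrMaxSharedMemoryPerBlockOptin. The host-only sketch below shows
// that query in isolation; it is an illustration with minimal error handling,
// not vLLM code.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  int max_shared = 0;
  cudaError_t err = cudaDeviceGetAttribute(
      &max_shared, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "query failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  std::printf("max opt-in shared memory per block: %d bytes\n", max_shared);
  return 0;
}
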
@ -7,11 +7,11 @@
|
||||
|
||||
// fake pointer type
|
||||
using fptr_t = uint64_t;
|
||||
static_assert(sizeof(void *) == sizeof(fptr_t));
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, int rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor &t) {
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
|
||||
bool full_nvlink) {
|
||||
auto inp_size = inp.numel() * inp.element_size();
|
||||
// custom allreduce requires input byte size to be multiples of 16
|
||||
@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
return false;
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
cudaStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
|
||||
reinterpret_cast<float *>(out.data_ptr()),
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
|
||||
reinterpret_cast<half *>(out.data_ptr()),
|
||||
out.numel());
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
|
||||
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out) {
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
||||
torch::Tensor& out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int meta_size() { return sizeof(vllm::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
return fa->get_graph_buffer_ipc_meta();
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
||||
|
||||
@ -31,9 +31,9 @@ struct Signal {
|
||||
alignas(128) uint32_t end[kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
|
||||
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
|
||||
|
||||
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
|
||||
struct __align__(16) RankSignals { volatile Signal* signals[8]; };
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half &assign_add(half &a, half b) {
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float &assign_add(float &a, float b) { return a += b; }
|
||||
DINLINE float& assign_add(float& a, float b) { return a += b; }
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
|
||||
@ -80,14 +80,14 @@ template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
|
||||
int rank) {
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
|
||||
int rank) {
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x]);
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P *ptrs[], int idx) {
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
|
||||
volatile Signal *self_sg, T *__restrict__ result,
|
||||
cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
|
||||
volatile Signal* self_sg, T* __restrict__ result,
|
||||
int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P *)result)[idx] =
|
||||
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P *get_tmp_buf(volatile Signal *sg) {
|
||||
return (P *)(((Signal *)sg) + 1);
|
||||
DINLINE P* get_tmp_buf(volatile Signal* sg) {
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
|
||||
volatile Signal *self_sg, T *__restrict__ result,
|
||||
cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
|
||||
volatile Signal* self_sg, T* __restrict__ result,
|
||||
int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P *ptrs[ngpus];
|
||||
P *tmps[ngpus];
|
||||
const P* ptrs[ngpus];
|
||||
P* tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P *)_dp->ptrs[target];
|
||||
ptrs[i] = (const P*)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P *)result)[dst_idx] = tmps[i][idx];
|
||||
((P*)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -261,14 +258,14 @@ class CustomAllreduce {
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void *, RankData *> buffers_;
|
||||
Signal *self_sg_;
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void *> graph_unreg_buffers_;
|
||||
std::vector<void*> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char *> ipc_handles_;
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
@ -279,22 +276,22 @@ class CustomAllreduce {
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
|
||||
const cudaIpcMemHandle_t *handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
|
||||
const cudaIpcMemHandle_t* handles,
|
||||
const std::vector<int64_t>& offsets, int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Signal *rank_sg;
|
||||
Signal* rank_sg;
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(&handles[i]);
|
||||
char* handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_sg = (Signal *)handle;
|
||||
rank_sg = (Signal*)handle;
|
||||
} else {
|
||||
rank_sg = self_sg_;
|
||||
}
|
||||
@ -302,13 +299,13 @@ class CustomAllreduce {
|
||||
}
|
||||
}
|
||||
|
||||
char *open_ipc_handle(const void *ipc_handle) {
|
||||
char* open_ipc_handle(const void* ipc_handle) {
|
||||
auto [it, new_handle] =
|
||||
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
|
||||
ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char *ipc_ptr;
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t *)ipc_handle),
|
||||
char* ipc_ptr;
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t*)ipc_handle),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
@ -323,7 +320,7 @@ class CustomAllreduce {
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void *base_ptr;
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr,
|
||||
@ -331,8 +328,8 @@ class CustomAllreduce {
|
||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(cudaIpcGetMemHandle(
|
||||
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
|
||||
(cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
@ -344,13 +341,13 @@ class CustomAllreduce {
|
||||
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, void *self) {
|
||||
void register_buffer(const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, void* self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(handles[i].data());
|
||||
char* handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
@ -371,17 +368,17 @@ class CustomAllreduce {
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto &rd = rank_data[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char *handle =
|
||||
char* handle =
|
||||
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
@ -405,7 +402,7 @@ class CustomAllreduce {
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T *input, T *output, int size,
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
@ -418,7 +415,7 @@ class CustomAllreduce {
|
||||
std::to_string(kMaxBlocks) + ". Got " +
|
||||
std::to_string(block_limit));
|
||||
|
||||
RankData *ptrs;
|
||||
RankData* ptrs;
|
||||
cudaStreamCaptureStatus status;
|
||||
CUDACHECK(cudaStreamIsCapturing(stream, &status));
|
||||
if (status == cudaStreamCaptureStatusActive) {
|
||||
|
||||
@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_data(T *data, int size, int myRank) {
|
||||
__global__ void set_data(T* data, int size, int myRank) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
data[idx] = myRank * 0.11f;
|
||||
@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
|
||||
double *fdata2, int size) {
|
||||
__global__ void convert_data(const T* data1, const T* data2, double* fdata1,
|
||||
double* fdata2, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
fdata1[idx] = data1[idx];
|
||||
@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
|
||||
__global__ void init_rand(curandState_t* state, int size, int nRanks) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
|
||||
__global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
|
||||
int myRank, int nRanks, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
|
||||
int data_size, bool performance_test) {
|
||||
T *result;
|
||||
T* result;
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
|
||||
@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
|
||||
cudaIpcMemHandle_t self_data_handle;
|
||||
cudaIpcMemHandle_t data_handles[8];
|
||||
vllm::Signal *buffer;
|
||||
T *self_data_copy;
|
||||
vllm::Signal* buffer;
|
||||
T* self_data_copy;
|
||||
/**
|
||||
* Allocate IPC buffer
|
||||
*
|
||||
@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
|
||||
MPI_BYTE, MPI_COMM_WORLD));
|
||||
|
||||
void *rank_data;
|
||||
void* rank_data;
|
||||
size_t rank_data_sz = 16 * 1024 * 1024;
|
||||
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
|
||||
std::vector<int64_t> offsets(nRanks, 0);
|
||||
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
|
||||
offsets, myRank);
|
||||
auto *self_data =
|
||||
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
|
||||
sizeof(vllm::Signal) + data_size * sizeof(T));
|
||||
auto* self_data =
|
||||
reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
|
||||
sizeof(vllm::Signal) + data_size * sizeof(T));
|
||||
// hack buffer registration
|
||||
{
|
||||
std::vector<std::string> handles;
|
||||
handles.reserve(nRanks);
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
char *begin = (char *)&data_handles[i];
|
||||
char *end = (char *)&data_handles[i + 1];
|
||||
char* begin = (char*)&data_handles[i];
|
||||
char* end = (char*)&data_handles[i + 1];
|
||||
handles.emplace_back(begin, end);
|
||||
}
|
||||
std::vector<int64_t> offsets(nRanks,
|
||||
@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
fa.register_buffer(handles, offsets, self_data);
|
||||
}
|
||||
|
||||
double *ground_truth;
|
||||
double* ground_truth;
|
||||
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
|
||||
curandState_t *states;
|
||||
curandState_t* states;
|
||||
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
|
||||
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
|
||||
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
|
||||
@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
CUDACHECK(cudaStreamDestroy(stream));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int main(int argc, char** argv) {
|
||||
int nRanks, myRank;
|
||||
MPICHECK(MPI_Init(&argc, &argv));
|
||||
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
|
||||
@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
|
||||
ncclUniqueId id;
|
||||
ncclComm_t comm;
|
||||
if (myRank == 0) ncclGetUniqueId(&id);
|
||||
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
|
||||
MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
|
||||
MPI_COMM_WORLD));
|
||||
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
|
||||
|
||||
|
||||
@@ -6,32 +6,30 @@
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
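// --- Illustrative usage sketch (not part of the diff) ----------------------
// The dispatch macros above expand to an AT_DISPATCH_SWITCH over the listed
// scalar types and make `scalar_t` available inside the lambda. `my_kernel`,
// `grid`, `block`, and `stream` are hypothetical placeholders.
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_kernel", [&] {
  my_kernel<scalar_t><<<grid, block, 0, stream>>>(
      input.data_ptr<scalar_t>(), input.numel());
});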
|
||||
|
||||
@@ -11,26 +11,24 @@
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template<typename scalar_t>
|
||||
template <typename scalar_t>
|
||||
__global__ void rms_norm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
const float x = (float) input[blockIdx.x * hidden_size + idx];
|
||||
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float) input[blockIdx.x * hidden_size + idx];
|
||||
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Converter structs for the conversion from torch types to HIP/CUDA types,
|
||||
and the associated type conversions within HIP/CUDA. These helpers need
|
||||
to be implemented for now because the relevant type conversion
|
||||
@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel(
|
||||
|
||||
Each struct should have the member static constexpr bool `exists`:
|
||||
If false, the optimized kernel is not used for the corresponding torch type.
|
||||
If true, the struct should be fully defined as shown in the examples below.
|
||||
If true, the struct should be fully defined as shown in the examples below.
|
||||
*/
|
||||
template<typename torch_type>
|
||||
struct _typeConvert { static constexpr bool exists = false; };
|
||||
template <typename torch_type>
|
||||
struct _typeConvert {
|
||||
static constexpr bool exists = false;
|
||||
};
|
||||
|
||||
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
// CUDA < 12.0 runs into issues with packed type conversion
|
||||
template<>
|
||||
template <>
|
||||
struct _typeConvert<c10::Half> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __half;
|
||||
using packed_hip_type = __half2;
|
||||
|
||||
__device__ static inline float convert(hip_type x) { return __half2float(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
|
||||
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
|
||||
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) {
|
||||
return __half22float2(x);
|
||||
}
|
||||
__device__ static inline hip_type convert(float x) {
|
||||
return __float2half_rn(x);
|
||||
}
|
||||
__device__ static inline packed_hip_type convert(float2 x) {
|
||||
return __float22half2_rn(x);
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
// CUDA_ARCH < 800 does not have BF16 support
|
||||
// TODO: Add in ROCm support once public headers handle bf16 maturely
|
||||
template<>
|
||||
template <>
|
||||
struct _typeConvert<c10::BFloat16> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __nv_bfloat16;
|
||||
using packed_hip_type = __nv_bfloat162;
|
||||
|
||||
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
|
||||
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
|
||||
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
|
||||
__device__ static inline float convert(hip_type x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
__device__ static inline float2 convert(packed_hip_type x) {
|
||||
return __bfloat1622float2(x);
|
||||
}
|
||||
__device__ static inline hip_type convert(float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
__device__ static inline packed_hip_type convert(float2 x) {
|
||||
return __float22bfloat162_rn(x);
|
||||
}
|
||||
};
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
||||
// 12000))
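// --- Illustrative sketch (not part of the diff) -----------------------------
// How the `exists` member described above gates the optimized path at compile
// time. `launch_example` and the two branches are hypothetical placeholders;
// the kernels in this file achieve the same effect with std::enable_if_t on
// the kernel's return type.
template <typename scalar_t>
void launch_example() {
  if constexpr (_typeConvert<scalar_t>::exists) {
    // _typeConvert<scalar_t> is fully defined: packed/vectorized kernel path.
  } else {
    // No conversion helpers for this type: generic scalar fallback.
  }
}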
|
||||
|
||||
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
|
||||
for appropriate specializations of fused_add_rms_norm_kernel.
|
||||
Only functions that are necessary in that kernel are implemented.
|
||||
Alignment to 16 bytes is required to use 128-bit global memory ops.
|
||||
*/
|
||||
template<typename scalar_t, int width>
|
||||
template <typename scalar_t, int width>
|
||||
struct alignas(16) _f16Vec {
|
||||
/* Not theoretically necessary that width is a power of 2 but should
|
||||
almost always be the case for optimization purposes */
|
||||
/* Not theoretically necessary that width is a power of 2 but should
|
||||
almost always be the case for optimization purposes */
|
||||
static_assert(width > 0 && (width & (width - 1)) == 0,
|
||||
"Width is not a positive power of 2!");
|
||||
using Converter = _typeConvert<scalar_t>;
|
||||
@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
|
||||
|
||||
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
T2 temp{data[i], data[i+1]};
|
||||
temp += T2{other.data[i], other.data[i+1]};
|
||||
T2 temp{data[i], data[i + 1]};
|
||||
temp += T2{other.data[i], other.data[i + 1]};
|
||||
data[i] = temp.x;
|
||||
data[i+1] = temp.y;
|
||||
data[i + 1] = temp.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i)
|
||||
data[i] += other.data[i];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) data[i] += other.data[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
T2 temp{data[i], data[i+1]};
|
||||
temp *= T2{other.data[i], other.data[i+1]};
|
||||
T2 temp{data[i], data[i + 1]};
|
||||
temp *= T2{other.data[i], other.data[i + 1]};
|
||||
data[i] = temp.x;
|
||||
data[i+1] = temp.y;
|
||||
data[i + 1] = temp.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i)
|
||||
data[i] *= other.data[i];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ _f16Vec& operator*=(const float scale) {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
|
||||
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
|
||||
temp_f.x *= scale;
|
||||
temp_f.y *= scale;
|
||||
T2 temp = Converter::convert(temp_f);
|
||||
data[i] = temp.x;
|
||||
data[i+1] = temp.y;
|
||||
data[i + 1] = temp.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
float temp = Converter::convert(data[i]) * scale;
|
||||
data[i] = Converter::convert(temp);
|
||||
@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
|
||||
__device__ float sum_squares() const {
|
||||
float result = 0.0f;
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i+1]});
|
||||
float2 z = Converter::convert(T2{data[i], data[i + 1]});
|
||||
result += z.x * z.x + z.y * z.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
float x = Converter::convert(data[i]);
|
||||
result += x * x;
|
||||
@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck. */
|
||||
template<typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<
|
||||
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
// Sanity checks on our vector struct and type-punned pointer arithmetic
|
||||
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
|
||||
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
not aliased in practice. Argument pointers should not be dereferenced
|
||||
in this kernel as that would be undefined behavior */
|
||||
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
||||
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
||||
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||
auto* __restrict__ input_v =
|
||||
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
||||
auto* __restrict__ residual_v =
|
||||
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
||||
auto* __restrict__ weight_v =
|
||||
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
@@ -215,10 +229,11 @@ __global__ std::enable_if_t<
|
||||
residual_v[id] = temp;
|
||||
}
|
||||
/* Keep the following if-else block in sync with the
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else variance = blockReduceSum<float, 256>(variance);
|
||||
} else
|
||||
variance = blockReduceSum<float, 256>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
@@ -233,52 +248,50 @@ __global__ std::enable_if_t<
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Generic fused_add_rms_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template<typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<
|
||||
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
scalar_t z = input[blockIdx.x * hidden_size + idx];
|
||||
z += residual[blockIdx.x * hidden_size + idx];
|
||||
float x = (float) z;
|
||||
float x = (float)z;
|
||||
variance += x * x;
|
||||
residual[blockIdx.x * hidden_size + idx] = z;
|
||||
}
|
||||
/* Keep the following if-else block in sync with the
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else variance = blockReduceSum<float, 256>(variance);
|
||||
} else
|
||||
variance = blockReduceSum<float, 256>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float) residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
@@ -286,40 +299,27 @@ void rms_norm(
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"rms_norm_kernel",
|
||||
[&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
epsilon,
|
||||
num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"fused_add_rms_norm_kernel", \
|
||||
[&] { \
|
||||
vllm::fused_add_rms_norm_kernel \
|
||||
<scalar_t, width><<<grid, block, 0, stream>>>( \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), \
|
||||
epsilon, \
|
||||
num_tokens, \
|
||||
hidden_size); \
|
||||
});
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), epsilon, \
|
||||
num_tokens, hidden_size); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
@@ -342,8 +342,8 @@ void fused_add_rms_norm(
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
|
||||
&& wt_ptr % 16 == 0;
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
|
||||
@@ -3,5 +3,6 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
|
||||
m.def("topk_softmax", &topk_softmax,
|
||||
"Apply topk softmax to the gating outputs.");
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
||||
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
||||
|
||||
@@ -19,15 +19,22 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/util_type.hpp>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
|
||||
/// Aligned array type
|
||||
template <
|
||||
typename T,
|
||||
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
||||
thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
|
||||
}
|
||||
|
||||
// From this point, thread max in all the threads have the max within the row.
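// --- Illustrative sketch (not part of the diff) -----------------------------
// The XOR-shuffle "butterfly" reduction used above: after log2(WIDTH) exchange
// steps every lane in a group of WIDTH threads holds the group maximum. WIDTH
// stands in for THREADS_PER_ROW and `group_max` is a hypothetical helper; it
// is written with the plain CUDA intrinsic, whereas the diff switches to
// VLLM_SHFL_XOR_SYNC_WIDTH for ROCm portability.
template <int WIDTH>
__device__ float group_max(float v) {
#pragma unroll
  for (int mask = WIDTH / 2; mask > 0; mask /= 2) {
    v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, mask, WIDTH));
  }
  return v;
}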
|
||||
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
|
||||
row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
|
||||
}
|
||||
|
||||
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
|
||||
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
|
||||
float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this way
|
||||
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
||||
@@ -383,7 +390,7 @@ struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
|
||||
@@ -7,119 +7,128 @@
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
namespace vllm {
|
||||
|
||||
namespace {
|
||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
|
||||
// don't worry about overflow because num_experts is relatively small
|
||||
return row * total_col + col;
|
||||
}
|
||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
|
||||
int32_t col) {
|
||||
// don't worry about overflow because num_experts is relatively small
|
||||
return row * total_col + col;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
int32_t *sorted_token_ids,
|
||||
int32_t *expert_ids,
|
||||
int32_t *total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size,
|
||||
size_t numel) {
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids,
|
||||
int32_t* total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size, size_t numel) {
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
|
||||
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
int32_t* tokens_cnts =
|
||||
shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
|
||||
int32_t* cumsum =
|
||||
shared_mem + (num_experts + 1) *
|
||||
num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* In the first step we compute token_cnts[thread_index + 1][expert_index],
|
||||
* which counts how many tokens in the token shard of thread_index are
|
||||
* assigned to expert expert_index.
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
|
||||
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// We accumulate the token counts of all experts in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i - 1] +
|
||||
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
|
||||
block_size) *
|
||||
block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
/**
|
||||
* In the first step we compute token_cnts[thread_index + 1][expert_index],
|
||||
* which counts how many tokens in the token shard of thread_index are assigned
|
||||
* to expert expert_index.
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
__syncthreads();
|
||||
/**
|
||||
* For each expert, each thread processes the tokens of the corresponding
|
||||
* blocks and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
||||
i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// We accumulate the token counts of all experts in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/**
|
||||
* For each expert, each thread processes the tokens of the corresponding blocks
|
||||
* and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
|
||||
/**
|
||||
* Each thread processes a token shard, calculating the index of each token after
|
||||
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
|
||||
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
|
||||
* where * represents a padding value(preset in python).
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
|
||||
* stores the indices of the tokens processed by the expert with expert_id within
|
||||
* the current thread's token shard.
|
||||
*/
|
||||
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Each thread processes a token shard, calculating the index of each token
|
||||
* after sorting by expert number. Given the example topk_ids =
|
||||
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
|
||||
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
|
||||
* padding value(preset in python).
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||
* expert with expert_id needs to process, and
|
||||
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
|
||||
* processed by the expert with expert_id within the current thread's token
|
||||
* shard.
|
||||
*/
|
||||
int32_t rank_post_pad =
|
||||
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
|
||||
cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
|
||||
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
|
||||
int block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t shared_mem =
|
||||
((num_experts + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
|
||||
AT_CUDA_CHECK(
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<1, num_experts, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
block_size,
|
||||
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
});
|
||||
}
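// --- Illustrative host-side sketch (not part of the diff) -------------------
// Reproduces the counting / padded-cumsum / scatter steps of
// moe_align_block_size_kernel for the example in the comments above
// (topk_ids = [0,1,2,1,2,3,0,3,4], block_size = 4). The -1 padding value is an
// assumption for illustration; in vLLM the padding is preset by the caller.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  const int num_experts = 5, block_size = 4;

  // 1) Count how many tokens were routed to each expert.
  std::vector<int> counts(num_experts, 0);
  for (int e : topk_ids) ++counts[e];

  // 2) Cumulative offsets, rounding each expert's slot up to block_size.
  std::vector<int> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)
    cumsum[e + 1] =
        cumsum[e] + (counts[e] + block_size - 1) / block_size * block_size;

  // 3) Scatter token indices into their expert's padded region.
  std::vector<int> sorted_token_ids(cumsum[num_experts], -1);  // -1 = padding
  std::vector<int> offset(num_experts, 0);
  for (int i = 0; i < (int)topk_ids.size(); ++i) {
    const int e = topk_ids[i];
    sorted_token_ids[cumsum[e] + offset[e]++] = i;
  }

  // Prints: 0 6 -1 -1 1 3 -1 -1 2 4 -1 -1 5 7 -1 -1 8 -1 -1 -1
  for (int v : sorted_token_ids) std::printf("%d ", v);
  std::printf("\n");
}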
|
||||
|
||||
271
csrc/ops.h
@@ -3,204 +3,139 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, float epsilon);
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox,
|
||||
int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||
int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
);
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes
|
||||
);
|
||||
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes);
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy);
|
||||
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros, int split_k_iters, int thx,
|
||||
int thy);
|
||||
|
||||
torch::Tensor marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k);
|
||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor &a,
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &b_scales,
|
||||
torch::Tensor &g_idx,
|
||||
torch::Tensor &perm,
|
||||
torch::Tensor &workspace,
|
||||
int64_t num_bits,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full);
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& g_idx,
|
||||
torch::Tensor& perm, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &perm,
|
||||
int64_t size_k,
|
||||
int64_t size_n,
|
||||
int64_t num_bits);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale);
|
||||
|
||||
torch::Tensor gptq_gemm(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_exllama,
|
||||
int bit);
|
||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm,
|
||||
int bit);
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||
bool use_exllama, int bit);
|
||||
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
|
||||
|
||||
void dynamic_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
|
||||
int block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = uint64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink);
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, int rank,
|
||||
bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out);
|
||||
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
|
||||
bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
||||
torch::Tensor& out);
|
||||
void dispose(fptr_t _fa);
|
||||
int meta_size();
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets);
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
#endif
|
||||
|
||||
@@ -7,14 +7,10 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding(
|
||||
scalar_t* __restrict__ arr,
|
||||
const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim)
|
||||
{
|
||||
scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
if (IS_NEOX) {
|
||||
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* cache_ptr,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int rot_dim,
|
||||
const int token_idx,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride)
|
||||
{
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||
const int num_kv_heads, const int rot_dim, const int token_idx,
|
||||
const int64_t query_stride, const int64_t key_stride) {
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
  const int64_t* __restrict__ positions,              // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                       // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                         // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,         // [max_position, 2, rot_dim // 2]
  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
                                                        // or [num_tokens]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
  const scalar_t* cache_ptr =
      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

} // namespace vllm
}  // namespace vllm

void rotary_embedding(
  torch::Tensor& positions,      // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,          // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,            // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
  bool is_neox) {
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
@@ -135,36 +141,21 @@ void rotary_embedding(
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
          query_stride, key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size);
    }
  });
}

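For readers skimming the diff, here is a minimal host-side sketch of how the op above can be driven. It is illustrative only and not part of the change: the tensor shapes follow the comments in the signature shown above, while the linkage and the concrete sizes are assumptions.

// Illustrative sketch, not part of the diff. Assumes rotary_embedding() from
// above is linked in (for example via the vLLM torch extension) and that a
// CUDA device is available; all sizes below are placeholder values.
#include <torch/torch.h>

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

void run_rotary_embedding_example() {
  const int64_t num_tokens = 8, num_heads = 32, num_kv_heads = 8;
  const int head_size = 128, rot_dim = 128, max_position = 4096;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto positions = torch::arange(
      num_tokens, torch::dtype(torch::kInt64).device(torch::kCUDA));
  auto query = torch::randn({num_tokens, num_heads * head_size}, opts);
  auto key = torch::randn({num_tokens, num_kv_heads * head_size}, opts);
  auto cos_sin_cache = torch::randn({max_position, rot_dim}, opts);
  // Rotates query and key in place; one thread block per token, as in the
  // kernel above.
  rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                   /*is_neox=*/true);
}
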
/*
@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void batched_rotary_embedding(
  torch::Tensor& positions,             // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,                 // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,                   // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,         // [max_position, rot_dim]
  bool is_neox,
  int rot_dim,
  torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int rot_dim,
    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
) {
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
  int num_heads = query.size(-1) / head_size;
@@ -191,36 +183,21 @@ void batched_rotary_embedding(
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
  });
}

@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 2752) \
|
||||
f(in_T, out_T, W_T, narrow, 2816) \
|
||||
f(in_T, out_T, W_T, narrow, 3072) \
|
||||
f(in_T, out_T, W_T, narrow, 3328) \
|
||||
f(in_T, out_T, W_T, narrow, 3456) \
|
||||
f(in_T, out_T, W_T, narrow, 3584) \
|
||||
f(in_T, out_T, W_T, narrow, 4096) \
|
||||
@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 5504) \
|
||||
f(in_T, out_T, W_T, narrow, 5632) \
|
||||
f(in_T, out_T, W_T, narrow, 6144) \
|
||||
f(in_T, out_T, W_T, narrow, 6400) \
|
||||
f(in_T, out_T, W_T, narrow, 6848) \
|
||||
f(in_T, out_T, W_T, narrow, 6912) \
|
||||
f(in_T, out_T, W_T, narrow, 7168) \
|
||||
@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
f(in_T, out_T, W_T, narrow, 24576) \
|
||||
f(in_T, out_T, W_T, narrow, 27392) \
|
||||
f(in_T, out_T, W_T, narrow, 27648) \
|
||||
f(in_T, out_T, W_T, narrow, 28672) \
|
||||
f(in_T, out_T, W_T, narrow, 32000) \
|
||||
f(in_T, out_T, W_T, narrow, 32256) \
|
||||
@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, 2752, narrow) \
|
||||
f(in_T, out_T, W_T, 2816, narrow) \
|
||||
f(in_T, out_T, W_T, 3072, narrow) \
|
||||
f(in_T, out_T, W_T, 3328, narrow) \
|
||||
f(in_T, out_T, W_T, 3456, narrow) \
|
||||
f(in_T, out_T, W_T, 3584, narrow) \
|
||||
f(in_T, out_T, W_T, 4096, narrow) \
|
||||
@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, 5504, narrow) \
|
||||
f(in_T, out_T, W_T, 5632, narrow) \
|
||||
f(in_T, out_T, W_T, 6144, narrow) \
|
||||
f(in_T, out_T, W_T, 6400, narrow) \
|
||||
f(in_T, out_T, W_T, 6848, narrow) \
|
||||
f(in_T, out_T, W_T, 6912, narrow) \
|
||||
f(in_T, out_T, W_T, 7168, narrow) \
|
||||
@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, 22016, narrow) \
|
||||
f(in_T, out_T, W_T, 24576, narrow) \
|
||||
f(in_T, out_T, W_T, 27392, narrow) \
|
||||
f(in_T, out_T, W_T, 27648, narrow) \
|
||||
f(in_T, out_T, W_T, 28672, narrow) \
|
||||
f(in_T, out_T, W_T, 32000, narrow) \
|
||||
f(in_T, out_T, W_T, 32256, narrow) \
|
||||
|
||||
@ -1,8 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <cooperative_groups.h>
|
||||
#else
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#endif
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda/pipeline>
|
||||
#endif
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
@ -11,6 +17,24 @@
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
#ifdef USE_ROCM
|
||||
template <size_t len>
|
||||
__host__ __device__
|
||||
inline void* memcpy_blocking(void *dst, const void *src) {
|
||||
// Does not handle the case of long datatypes
|
||||
char *d = reinterpret_cast<char *>(dst);
|
||||
const char *s = reinterpret_cast<const char *>(src);
|
||||
size_t i = 0;
|
||||
#pragma unroll
|
||||
for (i = 0; i < len; ++i) {
|
||||
d[i] = s[i];
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
// nthrs = (32, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
float y = 0;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
|
||||
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
|
||||
x_vec.load(X + (batch_idx * feat_in) +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
}
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
|
||||
y += sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
y_warpwise[threadIdx.y] = y;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float y_write = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y_write += y_warpwise[i];
|
||||
}
|
||||
|
||||
// write Y;
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
size_t y_idx = batch_idx * full_y_size + y_offset + j;
|
||||
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// nthrs = (2, 16, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
|
||||
typename in_T, typename out_T, typename W_T>
|
||||
@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
#ifndef USE_ROCM
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
#else
|
||||
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
|
||||
#endif
|
||||
}
|
||||
|
||||
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
|
||||
@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
sum = g.shfl(sum, 0);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
#ifndef USE_ROCM
|
||||
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
|
||||
#else
|
||||
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y;
|
||||
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
scale);
|
||||
}
|
||||
} else {
|
||||
#ifndef USE_ROCM
|
||||
static_assert(feat_in % (vec_size * 32) == 0 ||
|
||||
feat_in % (vec_size * 16) == 0 ||
|
||||
feat_in % (vec_size * 8) == 0);
|
||||
@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
#else
|
||||
constexpr size_t rocm_warp_size = warpSize;
|
||||
|
||||
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
|
||||
feat_in % (rocm_warp_size * vec_size_) == 0
|
||||
|
||||
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
|
||||
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
|
||||
constexpr size_t vec_size_shrink = vec_size_; \
|
||||
constexpr int tx = tx_; \
|
||||
constexpr int ty = ty_; \
|
||||
dim3 nblks(feat_out, batch_size); \
|
||||
dim3 nthrs(tx, ty); \
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
|
||||
vec_size_shrink * sizeof(in_T), \
|
||||
vec_size_shrink * sizeof(W_T), \
|
||||
tx, ty, tz> \
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
|
||||
full_y_size, num_layers, layer_idx, \
|
||||
scale); \
|
||||
}
|
||||
|
||||
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
|
||||
CHECK_INPUT_TILEABLE_BY(16) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 8) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 4) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 2) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 1));
|
||||
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
|
||||
|
||||
#undef CHECK_INPUT_TILEABLE_BY
|
||||
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
#ifndef VEC_DTYPES_CUH_
|
||||
#define VEC_DTYPES_CUH_
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#ifdef FLASHINFER_USE_FP8
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
@ -10,6 +8,9 @@
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "../type_convert.h"
|
||||
#include "../../cuda_compat.h"
|
||||
|
||||
#define FLASHINFER_INLINE \
|
||||
inline __attribute__((always_inline)) __device__ __host__
|
||||
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cstdint>
|
||||
|
||||
#include "type_convert.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "bgmv/bgmv_config.h"
|
||||
|
||||
namespace {
|
||||
|
||||
//====== utils ======
|
||||
|
||||
@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//====== pybind ======
|
||||
|
||||
#define DEFINE_pybind(name) m.def(#name, &name, #name);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
|
||||
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
|
||||
"dispatch_bgmv_low_level");
|
||||
}
|
||||
csrc/punica/punica_ops.h (new file, 11 lines)
@@ -0,0 +1,11 @@
#pragma once

#include <torch/extension.h>

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale);

void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset);
csrc/punica/punica_pybind.cpp (new file, 13 lines)
@@ -0,0 +1,13 @@
#include <torch/extension.h>

#include "punica_ops.h"

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
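With the declarations split out into punica_ops.h, the pybind registration above and any other translation unit can share the same prototypes. The following sketch is illustrative only and not part of the diff: it shows a C++ caller of dispatch_bgmv, with the tensor shapes and values left as unspecified placeholders because their layout is defined by the BGMV kernels, not by this header.

// Illustrative sketch, not part of the diff.
#include <torch/torch.h>

#include "punica_ops.h"

void run_dispatch_bgmv_example(torch::Tensor y, torch::Tensor x,
                               torch::Tensor w, torch::Tensor indicies) {
  const int64_t layer_idx = 0;   // placeholder values
  const float scale = 1.0f;
  // Performs the per-batch LoRA matrix-vector product selected by `indicies`
  // and `layer_idx`, scaled by `scale`, accumulating into `y` (see the
  // bgmv_shrink_kernel / bgmv_expand_kernel definitions earlier in the diff).
  dispatch_bgmv(y, x, w, indicies, layer_idx, scale);
}
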
csrc/punica/type_convert.h (new file, 82 lines)
@@ -0,0 +1,82 @@
|
||||
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
|
||||
#define CSRC__PUNICA__TYPE_CONVERT_H__
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#else
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
|
||||
|
||||
typedef __half nv_half;
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
|
||||
return __hip_bfloat162{val, val};
|
||||
}
|
||||
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
|
||||
return __hip_bfloat162{vall, valr};
|
||||
}
|
||||
|
||||
template <typename T_src, typename T_dst>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline T_dst convert_type(T_src val) {
|
||||
return static_cast<T_dst>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline float convert_type<__half, float>(__half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __half convert_type<float, __half>(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline T vllm_add(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __half vllm_add<__half>(__half a, __half b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
#undef __TYPE_CONVERT__HOST_DEVICE__
|
||||
|
||||
#endif // USE_ROCM
|
||||
|
||||
#endif // CSRC__PUNICA__TYPE_CONVERT_H__
|
||||
csrc/pybind.cpp (143 lines)
@@ -8,114 +8,90 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||
|
||||
// Attention ops
|
||||
ops.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
ops.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
ops.def("paged_attention_v1", &paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached "
|
||||
"keys/values using PagedAttention.");
|
||||
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
|
||||
|
||||
// Activation ops
|
||||
ops.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
ops.def(
|
||||
"gelu_and_mul",
|
||||
&gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def(
|
||||
"gelu_tanh_and_mul",
|
||||
&gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
ops.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
|
||||
ops.def("gelu_and_mul", &gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
|
||||
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
|
||||
|
||||
// Layernorm
|
||||
ops.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
ops.def("rms_norm", &rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
ops.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
|
||||
// Rotary embedding
|
||||
ops.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
ops.def("rotary_embedding", &rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
ops.def(
|
||||
"batched_rotary_embedding",
|
||||
&batched_rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
|
||||
ops.def("batched_rotary_embedding", &batched_rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
|
||||
"(supports multiple loras)");
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
|
||||
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
|
||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
|
||||
ops.def("marlin_gemm", &marlin_gemm,
|
||||
"Marlin (Dense) Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
|
||||
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
|
||||
"gptq_marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
|
||||
"gptq_marlin repack from GPTQ");
|
||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
|
||||
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
|
||||
"per-row/column quantization.");
|
||||
#endif
|
||||
|
||||
|
||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
|
||||
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
|
||||
ops.def(
|
||||
"moe_align_block_size",
|
||||
&moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
|
||||
"Compute FP8 quantized tensor for given scaling factor");
|
||||
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
|
||||
"Compute FP8 quantized tensor and scaling factor");
|
||||
ops.def("moe_align_block_size", &moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such "
|
||||
"that it is divisible by the block size.");
|
||||
|
||||
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
|
||||
"Compute int8 quantized tensor for given scaling factor");
|
||||
|
||||
// Cache ops
|
||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||
cache_ops.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache_flash",
|
||||
&reshape_and_cache_flash,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"convert_fp8",
|
||||
&convert_fp8,
|
||||
"Convert the key and value cache to fp8 data type");
|
||||
cache_ops.def("swap_blocks", &swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
cache_ops.def("copy_blocks", ©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
cache_ops.def("reshape_and_cache", &reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def("convert_fp8", &convert_fp8,
|
||||
"Convert the key and value cache to fp8 data type");
|
||||
|
||||
// Cuda utils
|
||||
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
cuda_utils.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
pybind11::module cuda_utils =
|
||||
m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
cuda_utils.def("get_device_attribute", &get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Custom all-reduce kernels
|
||||
@ -132,5 +108,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers,
|
||||
"register_graph_buffers");
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
@ -25,32 +25,28 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
|
||||
namespace vllm {
|
||||
namespace aqlm {
|
||||
|
||||
__global__ void Code1x16MatVec(
|
||||
const int4* __restrict__ A,
|
||||
const int4* __restrict__ B,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
const int prob_m,
|
||||
const int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||
int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
|
||||
const int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8)
|
||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
@ -76,22 +71,19 @@ __global__ void Code1x16MatVec(
|
||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint32_t dec[4];
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||
// actually help us; this brings > 2x speedup.
|
||||
asm volatile (
|
||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*) &codebook[enc[i]])
|
||||
);
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||
// that doesn't actually help us; this brings > 2x speedup.
|
||||
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*)&codebook[enc[i]]));
|
||||
half2* a = reinterpret_cast<half2*>(&dec);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
res2 = __hfma2(a[j], b[j], res2);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
|
||||
res += __half2float(res2.x) + __half2float(res2.y);
|
||||
b_sh_rd++;
|
||||
}
|
||||
@ -100,37 +92,33 @@ __global__ void Code1x16MatVec(
|
||||
}
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2)
|
||||
res += __shfl_down_sync(0xffffffff, res, i);
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void Code2x8MatVec(
|
||||
const int4* __restrict__ A,
|
||||
const int4* __restrict__ B,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
@ -148,9 +136,8 @@ __global__ void Code2x8MatVec(
|
||||
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++)
|
||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8)
|
||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
@ -170,13 +156,15 @@ __global__ void Code2x8MatVec(
|
||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2* a0 =
|
||||
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 =
|
||||
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
|
||||
res += __half2float(res2.x) + __half2float(res2.y);
|
||||
@ -187,36 +175,31 @@ __global__ void Code2x8MatVec(
|
||||
}
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2)
|
||||
res += __shfl_down_sync(0xffffffff, res, i);
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void Code1x16Dequant(
|
||||
const int4* __restrict__ A,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
|
||||
const int codebook_stride // as int4
|
||||
const int4* __restrict__ A, int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long, sums to m.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
@ -231,17 +214,15 @@ __global__ void Code1x16Dequant(
|
||||
while (iters--) {
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||
// actually help us; this brings > 2x speedup.
|
||||
asm volatile (
|
||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*) &codebook[enc[i]])
|
||||
);
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||
// that doesn't actually help us; this brings > 2x speedup.
|
||||
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*)&codebook[enc[i]]));
|
||||
|
||||
C[a_gl_rd * 8 + i] = chunk;
|
||||
}
|
||||
@ -250,28 +231,25 @@ __global__ void Code1x16Dequant(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void Code2x8Dequant(
|
||||
const int4* __restrict__ A,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
const int4* __restrict__ A, int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
@ -290,9 +268,8 @@ __global__ void Code2x8Dequant(
|
||||
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++)
|
||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -302,12 +279,14 @@ __global__ void Code2x8Dequant(
|
||||
while (iters--) {
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
#pragma unroll
|
||||
half2* a0 =
|
||||
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 =
|
||||
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
|
||||
C[a_gl_rd * 8 + i] = chunk;
|
||||
@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
|
||||
}
|
||||
}
|
||||
|
||||
inline int ceildiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
const int THREAD_M = 16;
|
||||
|
||||
void code1x16_matvec_cuda(
|
||||
const void* __restrict__ A,
|
||||
const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes,
|
||||
const int codebook_stride
|
||||
) {
|
||||
void code1x16_matvec_cuda(const void* __restrict__ A,
|
||||
const void* __restrict__ B, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m,
|
||||
int prob_k, const int4 codebook_a_sizes,
|
||||
const int codebook_stride) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
@ -345,28 +317,16 @@ void code1x16_matvec_cuda(
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>(
|
||||
(const int4*) A,
|
||||
(const int4*) B,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
|
||||
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||
prob_k, codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
void code2x8_matvec_cuda(
|
||||
const void* __restrict__ A,
|
||||
const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes,
|
||||
const int codebook_stride
|
||||
) {
|
||||
void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m,
|
||||
int prob_k, const int4 codebook_a_sizes,
|
||||
const int codebook_stride) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
@ -379,30 +339,20 @@ void code2x8_matvec_cuda(
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
cudaFuncSetAttribute(
|
||||
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||
);
|
||||
cudaFuncSetAttribute(Code2x8MatVec,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*) A,
|
||||
(const int4*) B,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||
prob_k, codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
void code1x16_dequant_cuda(
|
||||
const void* __restrict__ A,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
const void* __restrict__ A, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
@ -417,25 +367,21 @@ void code1x16_dequant_cuda(
|
||||
int threads = 32 * thread_m;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
|
||||
(const int4*) A,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
codebook_stride // as int4.
|
||||
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long.
|
||||
codebook_stride // as int4.
|
||||
);
|
||||
}
|
||||
|
||||
// Dequantizes the code and codebook into weights.
|
||||
void code2x8_dequant_cuda(
|
||||
const void* __restrict__ A,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
void code2x8_dequant_cuda(
|
||||
const void* __restrict__ A, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
@ -451,74 +397,50 @@ void code2x8_dequant_cuda(
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
cudaFuncSetAttribute(
|
||||
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||
);
|
||||
cudaFuncSetAttribute(Code2x8Dequant,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*) A,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||
codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
int codebook_stride(const torch::Tensor& codebooks)
|
||||
{
|
||||
int codebook_stride(const torch::Tensor& codebooks) {
|
||||
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
|
||||
}
|
||||
|
||||
void code1x16_matvec(
|
||||
const torch::Tensor& A,
|
||||
const torch::Tensor& B,
|
||||
torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
|
||||
code1x16_matvec_cuda(
|
||||
A.data_ptr(),
|
||||
B.data_ptr(),
|
||||
C.data_ptr(),
|
||||
codebook.data_ptr(),
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride(codebook)
|
||||
);
|
||||
code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||
codebook_stride(codebook));
|
||||
}
|
||||
|
||||
torch::Tensor code1x16_matmat(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
torch::Tensor code1x16_matmat(const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(input.dtype())
|
||||
.device(input.device())
|
||||
);
|
||||
auto flat_output = torch::empty(
|
||||
{flat_input.size(0), out_features},
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code1x16_matvec(
|
||||
codes.squeeze(2),
|
||||
input_vec,
|
||||
output_vec,
|
||||
codebooks,
|
||||
codebook_a_sizes
|
||||
);
|
||||
code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
|
||||
codebook_a_sizes);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
|
||||
@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
|
||||
return output;
|
||||
}
|
||||
|
||||
void code2x8_matvec(
|
||||
const torch::Tensor& A,
|
||||
const torch::Tensor& B,
|
||||
torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes
|
||||
) {
|
||||
void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
|
||||
torch::Tensor& C, const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
code2x8_matvec_cuda(
|
||||
A.data_ptr(),
|
||||
B.data_ptr(),
|
||||
C.data_ptr(),
|
||||
codebook.data_ptr(),
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
2 * codebook_stride(codebook)
|
||||
);
|
||||
code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||
2 * codebook_stride(codebook));
|
||||
}
|
||||
|
||||
torch::Tensor code2x8_matmat(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
) {
|
||||
torch::Tensor code2x8_matmat(const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(input.dtype())
|
||||
.device(input.device())
|
||||
);
|
||||
auto flat_output = torch::empty(
|
||||
{flat_input.size(0), out_features},
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code2x8_matvec(
|
||||
codes.squeeze(2),
|
||||
input_vec,
|
||||
output_vec,
|
||||
codebooks,
|
||||
codebook_a_sizes
|
||||
);
|
||||
code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
|
||||
codebook_a_sizes);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
if (bias.has_value()) {
|
||||
@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat(
|
||||
}
|
||||
|
||||
// Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
{
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes;
  auto cumulative_size = &cumulative_sizes.x;
  int i = 0;
  int last = 0;
  assert(codebook_partition_sizes.size(0) <= 4);
  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size)
  {
  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
    last = *cumulative_size;
  }
  // fill in the rest with unreachable.
  for (; i < 4; ++i, ++cumulative_size)
  {
    *cumulative_size = last*10;
  for (; i < 4; ++i, ++cumulative_size) {
    *cumulative_size = last * 10;
  }
  return cumulative_sizes;
}

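A quick worked example of the accumulation above, illustrative only and not part of the diff; the partition sizes are hypothetical values chosen for the arithmetic.

// Illustrative only: with three hypothetical partitions {4096, 4096, 2048},
// accumulate_sizes() yields {4096, 8192, 10240, 102400}; the unused fourth
// slot is last * 10, an unreachable sentinel so the codebook-advancing loops
// in the kernels above never walk past it.
void accumulate_sizes_example() {
  auto sizes = torch::tensor({4096, 4096, 2048});
  int4 cum = accumulate_sizes(sizes);
  assert(cum.x == 4096 && cum.y == 8192 && cum.z == 10240 && cum.w == 102400);
}
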
|
||||
} // namespace aqlm
|
||||
} // namespace vllm
|
||||
} // namespace aqlm
|
||||
} // namespace vllm
|
||||
|
||||
|
||||
torch::Tensor aqlm_gemm(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
)
|
||||
{
|
||||
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
int4 cumulative_sizes =
|
||||
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
|
||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
||||
int const entries = codebooks.size(1);
|
||||
|
||||
if (nbooks == 1 && entries == (1 << 16))
|
||||
{
|
||||
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||
if (nbooks == 1 && entries == (1 << 16)) {
|
||||
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
|
||||
cumulative_sizes, bias);
|
||||
}
|
||||
if (nbooks == 2 && entries == (1 << 8))
|
||||
{
|
||||
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||
if (nbooks == 2 && entries == (1 << 8)) {
|
||||
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
|
||||
cumulative_sizes, bias);
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
|
||||
" entries is not currently supported.")
|
||||
return {};
|
||||
}
|
||||
|

torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);
@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant(
  assert(out_features == codebook_partition_sizes.sum().item<int>());

  auto weights = torch::empty({out_features, in_features},
                              torch::TensorOptions()
                                  .dtype(codebooks.dtype())
                                  .device(codebooks.device()));

  if (nbooks == 1 && entries == (1 << 16)) {
    vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                      codebooks.data_ptr(), out_features,
                                      in_features, cumulative_sizes,
                                      vllm::aqlm::codebook_stride(codebooks));

    // If you wanted to flip to scaling the weights (though it's 30%-ish slower
    // and not consistent with the gemv implementation):
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  if (nbooks == 2 && entries == (1 << 8)) {
    vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                     codebooks.data_ptr(), out_features,
                                     in_features, cumulative_sizes,
                                     vllm::aqlm::codebook_stride(codebooks));

    // If you wanted to flip to scaling the weights (though it's 30%-ish slower
    // and not consistent with the gemv implementation):
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}
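aqlm_dequant materializes a dense [out_features, in_features] weight matrix and, per the comment above, deliberately leaves the scales for the caller to apply. A hedged caller-side sketch of how such a dequantize-then-matmul path could look in libtorch (my illustration, not code from this diff; the function name and its use as a fallback are assumptions):

// Hypothetical sketch: dequantize once, then run a regular matmul.
// weights is [out_features, in_features]; input is [..., in_features].
torch::Tensor aqlm_gemm_via_dequant(const torch::Tensor& input,
                                    const torch::Tensor& codes,
                                    const torch::Tensor& codebooks,
                                    const torch::Tensor& codebook_partition_sizes) {
  torch::Tensor weights =
      aqlm_dequant(codes, codebooks, codebook_partition_sizes);
  // Scales are not folded into the weights here, matching the commented-out
  // "weights *= scales" lines above, so the caller must still apply them.
  return torch::matmul(input, weights.t());
}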
@ -1,11 +1,11 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace vllm {
namespace awq {

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
  uint4 result;

  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // First, we extract the i4s and construct an intermediate fp16 number.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits beforehand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));

  // I use inline PTX below because I am not sure if the compiler will emit
  // float2half instructions if I use the half2 ctor. In this case, I chose
  // performance reliability over code readability.

  // This is the half2 {1032, 1032} represented as an integer.
  // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
  // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-72, -72} represented as an integer.
  // static constexpr uint32_t NEG_72 = 0xd480d480;
  // Haotian: Let's use {-64, -64}.
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[0])
               : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[1])
               : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[2])
               : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[3])
               : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#endif
}

}  // namespace awq
}  // namespace vllm
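The magic constants work because the fp16 bit pattern 0x6400 is 1024.0: OR-ing a 4-bit value v into the low mantissa bits yields exactly 1024 + v, so a single subtract recovers v, while a nibble sitting in bits 4-7 yields 1024 + 16*v, so a single fused multiply-add by 1/16 with an offset of -64 recovers it. A minimal host-side sketch of that arithmetic in plain float (my illustration, not code from this diff):

// Hedged sketch of the magic-number decode used by dequantize_s4_to_fp16x2.
#include <cassert>

float decode_low_nibble(unsigned v) {   // lane pattern (v | 0x6400) == 1024 + v
  float as_fp16 = 1024.0f + (float)v;
  return as_fp16 - 1024.0f;             // sub.f16x2 with FP16_TOP_MAGIC_NUM
}

float decode_high_nibble(unsigned v) {  // lane pattern ((v << 4) | 0x6400) == 1024 + 16*v
  float as_fp16 = 1024.0f + 16.0f * (float)v;
  return as_fp16 * (1.0f / 16.0f) - 64.0f;  // fma.rn.f16x2 with ONE_SIXTEENTH, NEG_64
}

int main() {
  for (unsigned v = 0; v < 16; ++v) {
    assert(decode_low_nibble(v) == (float)v);
    assert(decode_high_nibble(v) == (float)v);
  }
  return 0;
}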
@ -1,14 +1,12 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

@ -20,26 +18,20 @@ namespace vllm {
namespace awq {

// Pack two half values.
static inline __device__ __host__ unsigned __pack_half2(const half x,
                                                        const half y) {
  unsigned v0 = *((unsigned short*)&x);
  unsigned v1 = *((unsigned short*)&y);
  return (v1 << 16) | v0;
}

template <int N>
__global__ void __launch_bounds__(64)
    gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
                                    half* __restrict__ A, int* __restrict__ B,
                                    half* __restrict__ scaling_factors,
                                    int* __restrict__ zeros, int M, int IC,
                                    int OC, half* __restrict__ C) {
  // Only support matrix n = 64 or 128
  assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
bool ld_A_flag =
|
||||
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
|
||||
threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
half* A_ptr =
|
||||
A +
|
||||
(((int)blockIdx_y) / j_factors1 * 16 +
|
||||
(((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
|
||||
IC +
|
||||
(((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
|
||||
(((int)threadIdx.x) / (N / 8)) * (OC / 8) +
|
||||
(((int)blockIdx_y) % j_factors1) * (N / 8) +
|
||||
(((int)threadIdx.x) % (N / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
half* A_shared_ptr = A_shared +
|
||||
((int)threadIdx.y) * row_stride_warp * (32 + 8) +
|
||||
(((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
|
||||
(((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
half* B_shared_ptr = B_shared +
|
||||
((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
|
||||
(((int)threadIdx.x) / (N / 8)) * (N + 8) +
|
||||
(((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ ((int)threadIdx.x) % (N / 8);
|
||||
int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
|
||||
((int)threadIdx.x) % (N / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
half* scaling_factors_ptr = scaling_factors +
|
||||
(((int)blockIdx_y) % j_factors1) * N +
|
||||
(((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ ((int)threadIdx.y) * (N / 2)
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
half* C_ptr =
|
||||
C +
|
||||
static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
|
||||
(((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
if (ld_A_flag) {
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
uint4 B_loaded_scale =
|
||||
*(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
|
||||
threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
|
||||
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
|
||||
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
|
||||
// zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
|
||||
// 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
|
||||
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
|
||||
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
|
||||
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
|
||||
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded =
|
||||
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
|
||||
// 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
|
||||
// % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
|
||||
// q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
|
||||
0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
|
||||
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
|
||||
B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
|
||||
"addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
|
||||
(((((int)threadIdx.x) & 15) * 40) +
|
||||
((((int)threadIdx.x) >> 4) * 8)))));
|
||||
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"=r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"=r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"=r"(((unsigned*)(A_shared_warp + 0))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
|
||||
"addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
|
||||
(((int)threadIdx.y) * (N / 2))) +
|
||||
(ax1_0 * 16))])) +
|
||||
(((((int)threadIdx.x) & 15) * (N + 8)) +
|
||||
((((int)threadIdx.x) >> 4) * 8)))));
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
|
||||
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
|
||||
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
|
||||
"=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#else
|
||||
#else
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||
"%13};\n"
|
||||
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||
"%13};\n"
|
||||
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
|
||||
((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M) {
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
|
||||
local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
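Each z-slice of the grid handles a strided subset of the 32-wide K chunks (k_0_0 = _k_0_0 * split_k_iters + blockIdx_z above) and writes its partial sums into its own slice of C at offset blockIdx_z * M * OC, leaving the final reduction to the host. A minimal sketch of that partitioning with hypothetical sizes (the bounds guard is my own addition for the sketch):

// Sketch: how split-k divides the K dimension among blockIdx_z slices.
#include <cstdio>

void print_splitk_chunks(int IC, int split_k_iters) {
  int num_chunks = IC / 32;  // the kernel consumes K in 32-wide chunks
  int k_bound = (num_chunks + split_k_iters - 1) / split_k_iters;
  for (int z = 0; z < split_k_iters; ++z) {  // one grid z-slice each
    std::printf("blockIdx_z=%d:", z);
    for (int i = 0; i < k_bound; ++i) {
      int k_0_0 = i * split_k_iters + z;  // same indexing as the kernel loop
      if (k_0_0 < num_chunks) std::printf(" chunk %d", k_0_0);
    }
    std::printf("\n");
  }
}

int main() { print_splitk_chunks(/*IC=*/256, /*split_k_iters=*/2); return 0; }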
|
||||
|
||||
__global__ void __launch_bounds__(64)
    dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
                       int* __restrict__ zeros, half* __restrict__ C, int G) {
  int j_factors1 = 4;
  int row_stride2 = 4;
  int split_k_iters = 1;
@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
||||
|
||||
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.x)
|
||||
: "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.y)
|
||||
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.z)
|
||||
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(B_loaded_fp16.w)
|
||||
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
|
||||
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
||||
|
||||
@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights(
  }
}

}  // namespace awq
}  // namespace vllm

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int split_k_iters, int thx,
                             int thy) {
  int in_c = _kernel.size(0);
  int qout_c = _kernel.size(1);
  int out_c = qout_c * 8;
  int G = in_c / _scaling_factors.size(0);

  int x_thread = thx;
  int y_thread = thy;

  int x_blocks = 1;
  int y_blocks = 1;
  if (thx == 0) {
    x_thread = qout_c;
  }
  if (thy == 0) {
    y_thread = in_c;
  }
  if (thx == 0 && thy == 0) {
    x_thread = 8;
    y_thread = 8;
    x_blocks = (int)(qout_c / 8);
    y_blocks = (int)(in_c / 8);
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

  auto options = torch::TensorOptions()
                     .dtype(_scaling_factors.dtype())
                     .device(_scaling_factors.device());
  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);

  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
  auto scaling_factors =
      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_thread, y_thread);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
      kernel, scaling_factors, zeros, de_kernel, G);

  return _de_kernel;
}
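A hedged usage sketch with hypothetical shapes: for a 4-bit AWQ layer with IC = 4096, OC = 11008 and group size 128, _kernel is [4096, 1376] int32 (eight 4-bit values per int32), _scaling_factors is [32, 11008] half, _zeros is [32, 1376] int32, and the call returns a dense [4096, 11008] half matrix. The tensor names below are illustrative, and split_k_iters is unused by the visible body of this function, so it is shown as 0 here.

// Sketch only; thx/thy of 0 let the function pick the 8x8 thread tiling itself.
torch::Tensor dense_weight = awq_dequantize(qweight, scales, qzeros,
                                            /*split_k_iters=*/0,
                                            /*thx=*/0, /*thy=*/0);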

// in_feats: M, IC [float16]
@ -386,61 +489,61 @@ torch::Tensor awq_dequantize(
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int split_k_iters) {
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto scaling_factors =
      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
  int group_size = num_in_channels / _scaling_factors.size(0);

  if (num_out_channels % 64 != 0)
    throw std::invalid_argument("OC is not multiple of cta_N = 64");
  if (num_out_channels % 8 != 0)
    throw std::invalid_argument("OC is not multiple of pack_num = 8");
  if (group_size % 32 != 0)
    throw std::invalid_argument("Group size should be a multiple of 32");
  if (num_out_channels % group_size != 0)
    throw std::invalid_argument("OC is not multiple of Group size");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (num_out_channels % 128 == 0) {
    int j_factors1 = num_out_channels / 128 / 1;
    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  } else if (num_out_channels % 64 == 0) {
    int j_factors1 = num_out_channels / 64 / 1;
    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
                    split_k_iters);

    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  return _out_feats.sum(0);
}
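The returned tensor folds the split-k partials together: _out_feats is laid out as [split_k_iters, M, OC], each z-slice of the grid having written its own partial GEMM, and .sum(0) reduces them to the final [M, OC] result. An explicit-loop equivalent of that reduction, for illustration only:

// Same reduction as _out_feats.sum(0) above, spelled out as a loop.
torch::Tensor reduce_splitk(const torch::Tensor& partial) {  // [split_k, M, OC]
  torch::Tensor out = torch::zeros_like(partial[0]);         // [M, OC]
  for (int64_t z = 0; z < partial.size(0); ++z) {
    out += partial[z];  // accumulate the partial product of K slice z
  }
  return out;
}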

62  csrc/quantization/compressed_tensors/int8_quant_kernels.cu  (new file)
@ -0,0 +1,62 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cmath>

#include "../../dispatch_utils.h"

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
    const scale_type* scale_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;
  scale_type scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] =
        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
  }
}
}  // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scale.data_ptr<float>(), hidden_size);
      });
}
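In scalar terms the kernel computes out[t, i] = saturate(round_to_nearest(input[t, i] / scale)) for every token row t, with one CUDA block per token and threads striding over hidden_size. A hedged host-side reference of the same math (an illustration, not part of this file):

// CPU reference for float_to_int8_rn plus the per-token loop.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static int8_t float_to_int8_rn_ref(float x) {
  float r = std::nearbyint(x);         // round to nearest, ties to even (like cvt.rni)
  r = std::clamp(r, -128.0f, 127.0f);  // saturate to the int8 range (like .sat)
  return static_cast<int8_t>(r);
}

std::vector<int8_t> static_scaled_int8_quant_ref(const std::vector<float>& row,
                                                 float scale) {
  std::vector<int8_t> out(row.size());
  for (size_t i = 0; i < row.size(); ++i) {
    out[i] = float_to_int8_rn_ref(row[i] / scale);
  }
  return out;
}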

346  csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp  (new file)
@ -0,0 +1,346 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

namespace cutlass::epilogue::threadblock {

using namespace cute;
using namespace detail;

template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };
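The row_broadcast flag is what lets one compiled epilogue serve both granularities: for per-channel or per-token scales, ptr_row points at a length-N device vector and row_broadcast stays true; for a per-tensor scale, ptr_row points at a single device element and row_broadcast is set to false so every column reuses that scalar. A self-contained sketch of the two ways the arguments get populated (simplified with a stand-in float element type; the real struct is templated, and these helper names are mine):

// Simplified stand-in for the Arguments struct above.
struct ArgumentsSketch {
  const float* ptr_row = nullptr;
  bool row_broadcast = true;
};

ArgumentsSketch per_channel_args(const float* device_scale_vector) {
  return {device_scale_vector, /*row_broadcast=*/true};   // N scales, one per column
}

ArgumentsSketch per_tensor_args(const float* device_scale_scalar) {
  return {device_scale_scalar, /*row_broadcast=*/false};  // one scale, broadcast to all
}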
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gRow,
|
||||
RTensor&& tC_rRow,
|
||||
CTensor&& tC_cRow,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||
n(get<1>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gRow;
|
||||
RTensor tC_rRow;
|
||||
CTensor tC_cRow;
|
||||
Params const* params_ptr;
|
||||
int n;
|
||||
|
||||
// This function is modified from VisitorRowBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rRow);
|
||||
auto src_v = filter(tC_gRow);
|
||||
auto coord_v = filter(tC_cRow);
|
||||
auto dst_v = filter(tC_rRow);
|
||||
|
||||
if (params_ptr->row_broadcast) {
|
||||
// In this case we are loading from a row vector and broadcasting
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
bool guard = get<1>(coord_v(i)) < n;
|
||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||
dst_v(i), (void const*)&src_v(i), guard);
|
||||
}
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
VecType filled_vec;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VecLength; i++) {
|
||||
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
if (get<1>(coord_v(i)) < n) {
|
||||
dst_v(i) = filled_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||
return rRow_frg(column_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mRow = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_row),
|
||||
problem_shape,
|
||||
params_ptr->dRow);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN
|
||||
Tensor tC_gRow = recast<VecType>(
|
||||
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||
)(_,_,_0{},_0{},_0{},_0{});
|
||||
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||
Tensor tC_cRow = outer_partition(
|
||||
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||
Shape<Int<VecLength>>{},
|
||||
(_0{})
|
||||
);
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gRow), decltype(tC_rRow),
|
||||
decltype(tC_cRow), ProblemShape>(
|
||||
cute::move(tC_gRow),
|
||||
cute::move(tC_rRow),
|
||||
cute::move(tC_cRow),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>
|
||||
>
|
||||
struct VisitorColOrScalarBroadcast {
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast.
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage { };
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gCol,
|
||||
RTensor&& tC_rCol,
|
||||
CTensor&& tC_cCol,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gCol(cute::forward<GTensor>(tC_gCol)),
|
||||
tC_rCol(cute::forward<RTensor>(tC_rCol)),
|
||||
tC_cCol(cute::forward<CTensor>(tC_cCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gCol;
|
||||
RTensor tC_rCol;
|
||||
CTensor tC_cCol;
|
||||
Params const* params_ptr;
|
||||
int m;
|
||||
|
||||
// This function is modified from VisitorColBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rCol);
|
||||
|
||||
Tensor pred = make_tensor<bool>(shape(tC_gCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tC_cCol(i)) < m;
|
||||
}
|
||||
|
||||
if (params_ptr->col_broadcast) {
|
||||
// In this case we are loading from a column vector and broadcasting
|
||||
copy_if(pred, tC_gCol, tC_rCol);
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
auto dst_v = filter(tC_rCol);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(dst_v); ++i) {
|
||||
if (pred(i)) {
|
||||
dst_v(i) = *(params_ptr->ptr_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
frg_col.fill(tC_rCol(row_idx,iter_idx));
|
||||
return frg_col;
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mCol = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_col),
|
||||
problem_shape,
|
||||
params_ptr->dCol);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
|
||||
Tensor tC_gCol = group_modes<1,4>(
|
||||
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
Tensor tC_rCol = make_tensor_like(tC_gCol);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tC_cCol = group_modes<1,4>(
|
||||
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gCol), decltype(tC_rCol),
|
||||
decltype(tC_cCol), ProblemShape>(
|
||||
cute::move(tC_gCol),
|
||||
cute::move(tC_rCol),
|
||||
cute::move(tC_cCol),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
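The visitors above fold per-tensor and per-channel/per-token scaling into a single compiled kernel by carrying a device pointer plus a row_broadcast/col_broadcast flag. The following standalone sketch is not part of the diff; it only illustrates the same pointer-plus-flag convention in plain host C++: when the flag is set, the pointer is indexed per element, otherwise the single scalar it points to is broadcast.

#include <cstdio>
#include <vector>

// Illustrative sketch only: mirrors the Arguments convention of
// VisitorRowOrScalarBroadcast / VisitorColOrScalarBroadcast (device pointer +
// broadcast flag), evaluated here on the host for clarity.
float load_scale(const float* ptr, bool broadcast, int idx) {
  // broadcast == true  -> ptr holds one value per row/column
  // broadcast == false -> ptr holds a single per-tensor scalar
  return broadcast ? ptr[idx] : ptr[0];
}

int main() {
  std::vector<float> per_channel{0.5f, 1.0f, 2.0f, 4.0f};  // per-column scales
  float per_tensor = 0.25f;                                // single scalar
  for (int j = 0; j < 4; ++j) {
    std::printf("col %d: per-channel=%.2f per-tensor=%.2f\n", j,
                load_scale(per_channel.data(), /*broadcast=*/true, j),
                load_scale(&per_tensor, /*broadcast=*/false, j));
  }
  return 0;
}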
389
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
Normal file
@@ -0,0 +1,389 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
* reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
|
||||
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_1,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcast {
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
|
||||
|
||||
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
|
||||
struct SharedStorage {
|
||||
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_row is null.
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params),
|
||||
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
|
||||
|
||||
Params params;
|
||||
Element* smem_row;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
|
||||
}
|
||||
|
||||
template <int EpiTiles, class GTensor, class STensor>
|
||||
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
|
||||
: gRow(cute::forward<GTensor>(gRow)),
|
||||
sRow(cute::forward<STensor>(sRow)),
|
||||
params(params) {}
|
||||
|
||||
GTensor gRow; // (CTA_M,CTA_N)
|
||||
STensor sRow; // (CTA_M,CTA_N,PIPE)
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
|
||||
if (params.ptr_row == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (issue_tma_load) {
|
||||
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
|
||||
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
|
||||
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
|
||||
// Issue the TMA bulk copy
|
||||
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||
|
||||
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
|
||||
cute::move(gRow), cute::move(sRow), params);
|
||||
}
|
||||
|
||||
template <int EpiTiles, class RTensor, class STensor>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
|
||||
: tCrRow(cute::forward<RTensor>(tCrRow)),
|
||||
tCsRow(cute::forward<STensor>(tCsRow)),
|
||||
params(params) {}
|
||||
|
||||
RTensor tCrRow; // (CPY,CPY_M,CPY_N)
|
||||
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tCrRow, *(params.ptr_row));
|
||||
return;
|
||||
}
|
||||
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_row;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tCrRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||
sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
|
||||
cute::move(tCrRow), cute::move(tCsRow), params);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90ColOrScalarBroadcast {
|
||||
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||
|
||||
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||
struct SharedStorage { };
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_col is null.
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
|
||||
Params params;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
|
||||
: tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col));
|
||||
return;
|
||||
}
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_aligned(filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_col;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
|
||||
cute::move(tCgCol), cute::move(tCrCol), params);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
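A quick arithmetic sketch, not part of the diff, of the shared-memory and TMA transaction sizing used by Sm90RowOrScalarBroadcast above: smem_row holds size<1>(CtaTileShapeMNK) * Stages elements, and each producer subtile load expects size<1>(CtaTileShapeMNK) * sizeof_bits<Element> / 8 bytes. The tile size, stage count, and element width below are assumed example values, not taken from the diff.

#include <cstdio>

int main() {
  int tile_n = 128;   // size<1>(CtaTileShapeMNK{}), example value
  int stages = 2;     // Stages template parameter, example value
  int elem_bits = 32; // float scales
  std::printf("smem_row elements: %d (%d bytes)\n", tile_n * stages,
              tile_n * stages * elem_bits / 8);
  std::printf("expect-tx bytes per subtile: %d\n", tile_n * elem_bits / 8);
  return 0;
}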
12
csrc/quantization/cutlass_w8a8/common.hpp
Normal file
@@ -0,0 +1,12 @@
#pragma once

#include "cutlass/cutlass.h"

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                          \
  {                                                    \
    TORCH_CHECK(status == cutlass::Status::kSuccess,   \
                cutlassGetStatusString(status))        \
  }
294
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
Normal file
@@ -0,0 +1,294 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "broadcast_load_epilogue_c2x.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
   This defines a quantized GEMM operation with dequantized output, similar to
   torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or per-column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per-row or per-column.
*/

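As a plain-C++ reference for the semantics described in the comment above (illustrative only, not part of this file): a_scales broadcasts over rows of A (one scale per row, or a single per-tensor scale), b_scales over columns of B, and because each scale depends on only one output index, the scales factor out of the reduction.

#include <cstdio>
#include <vector>

// Reference sketch of D = (a_scales * A) (b_scales * B) with numpy-style
// broadcasting of the scales. Dimensions and data are made-up example values.
int main() {
  const int M = 2, N = 3, K = 4;
  std::vector<int> A(M * K, 1), B(K * N, 2);   // stand-ins for int8 data
  std::vector<float> a_scales{0.5f, 2.0f};     // per-row (size M) or size 1
  std::vector<float> b_scales{1.0f};           // per-tensor (size 1) or size N
  std::vector<float> D(M * N, 0.0f);

  for (int i = 0; i < M; ++i) {
    float sa = a_scales.size() == 1 ? a_scales[0] : a_scales[i];
    for (int j = 0; j < N; ++j) {
      float sb = b_scales.size() == 1 ? b_scales[0] : b_scales[j];
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
      D[i * N + j] = sa * sb * acc;  // scales factor out of the k-sum
    }
  }
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) std::printf("%6.2f ", D[i * N + j]);
    std::printf("\n");
  }
  return 0;
}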
namespace {
|
||||
|
||||
template <typename Arch, typename ElementAB_, typename ElementD_,
|
||||
typename TileShape, typename WarpShape, typename InstructionShape,
|
||||
int32_t MainLoopStages>
|
||||
struct cutlass_2x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Operator =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::arch::OpMultiplyAdd>::type;
|
||||
|
||||
using OutputTileThreadMap =
|
||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||
>;
|
||||
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
|
||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||
Stride<int64_t, Int<1>, Int<0>>>;
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
|
||||
|
||||
// clang-format off
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using KernelType =
|
||||
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
float, cutlass::layout::RowMajor, 4,
|
||||
ElementAcc, float, cutlass::arch::OpClassTensorOp,
|
||||
Arch,
|
||||
TileShape, WarpShape, InstructionShape,
|
||||
EVTD,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||
MainLoopStages, Operator,
|
||||
1 /* epilogue stages */
|
||||
>::GemmKernel;
|
||||
// clang-format on
|
||||
|
||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
cutlass::gemm::GemmCoord problem_size{m, n, k};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
auto a_scales_ptr = a_scales.data_ptr<float>();
|
||||
auto b_scales_ptr = b_scales.data_ptr<float>();
|
||||
|
||||
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
|
||||
using ScaleBArgs = typename Gemm::ScaleB::Arguments;
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
|
||||
typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
|
||||
evt0_compute_args};
|
||||
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
||||
|
||||
typename Gemm::EVTD::Arguments epilogue_args{
|
||||
evt1_compute_args,
|
||||
d_args,
|
||||
};
|
||||
|
||||
typename Gemm::Op::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
|
||||
problem_size, // problem size
|
||||
1, // batch count
|
||||
epilogue_args,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldc};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.get(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 2>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
|
||||
WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
|
||||
WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
}
|
||||
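The Sm80EVT tree built in cutlass_2x_gemm composes as EVTD = D(EVTCompute1(ScaleA, EVTCompute0(ScaleB, Accum))), i.e. each output element is d = scale_a * (scale_b * acc). The standalone snippet below (not part of the diff; arbitrary example inputs) just spells out that composition for a single accumulator value.

#include <cstdio>

int main() {
  float acc = 42.0f;      // accumulator value fetched by Accum
  float scale_b = 0.25f;  // value produced by the ScaleB visitor
  float scale_a = 0.5f;   // value produced by the ScaleA visitor
  float compute0 = scale_b * acc;  // EVTCompute0
  float d = scale_a * compute0;    // EVTCompute1, then stored by D
  std::printf("d = %f\n", d);
  return 0;
}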
325
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
Normal file
@@ -0,0 +1,325 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "broadcast_load_epilogue_c3x.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
   This defines a quantized GEMM operation with dequantized output, similar to
   torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
   NVIDIA GPUs with sm90a (Hopper) or later.

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or per-column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per-row or per-column.
*/

namespace {
|
||||
|
||||
uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
||||
typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleBDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, float>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
|
||||
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
||||
EpilogueSchedule, EVTCompute1>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, 16,
|
||||
ElementAB, cutlass::layout::ColumnMajor, 16,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
using ScaleA_Args = typename Gemm::ScaleA::Arguments;
|
||||
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
|
||||
|
||||
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
|
||||
args.epilogue.thread = {a_args, {b_args}};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType, int32_t M>
|
||||
struct sm90_fp8_config {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
struct sm90_fp8_config<InType, OutType, 128> {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
struct sm90_fp8_config<InType, OutType, 64> {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _128>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 64) {
|
||||
// m in [1, 64]
|
||||
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
|
||||
return cutlass_scaled_mm_dq_dispatcher<
|
||||
cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
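The fp8 dispatch above buckets the GEMM by the next power of two of M, clamped to at least 64: mp2 <= 64 selects the M64 tile config, mp2 <= 128 the M128 config, and anything larger the default config. The standalone check below (illustrative only, not part of the diff) reuses the same next_pow_2 helper to show which bucket various M values land in.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Same rounding helper as in scaled_mm_dq_c3x.cu.
uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  for (uint32_t m : {1u, 17u, 64u, 65u, 96u, 128u, 129u, 4096u}) {
    uint32_t mp2 = std::max(64u, next_pow_2(m));
    const char* cfg = mp2 <= 64    ? "Cutlass3xGemmM64"
                      : mp2 <= 128 ? "Cutlass3xGemmM128"
                                   : "Cutlass3xGemmDefault";
    std::printf("m=%4u -> mp2=%4u -> %s\n", m, mp2, cfg);
  }
  return 0;
}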
75
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
Normal file
@@ -0,0 +1,75 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
int32_t major_capability;
|
||||
int32_t minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
0);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
0);
|
||||
int32_t version_num = major_capability * 10 + minor_capability;
|
||||
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
|
||||
if (version_num >= 90) {
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
|
||||
#else
|
||||
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
|
||||
#endif
|
||||
} else if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
|
||||
} else if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
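The entry point expects a row-major A (M x K), a column-major B (K x N), a row-major output C (M x N), and contiguous float32 scales with either one element or one per row/column. Below is a hedged host-side sketch of a call that satisfies those checks, assuming the extension exposing cutlass_scaled_mm_dq has been built and linked; the shapes, names, and dtypes are example choices, not taken from the diff.

#include <torch/extension.h>

// Declaration from scaled_mm_dq_entry.cu.
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
                          torch::Tensor const& b, torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales);

void example_call() {
  const int64_t M = 32, N = 64, K = 128;
  auto opts_i8 = torch::dtype(torch::kInt8).device(torch::kCUDA);
  auto opts_f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  auto a = torch::randint(-128, 127, {M, K}, opts_i8);      // row-major, stride(1) == 1
  auto b = torch::randint(-128, 127, {N, K}, opts_i8).t();  // column-major view, stride(0) == 1
  auto c = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
  auto a_scales = torch::rand({M}, opts_f32);  // per-row scales for A
  auto b_scales = torch::rand({1}, opts_f32);  // per-tensor scale for B
  cutlass_scaled_mm_dq(c, a, b, a_scales, b_scales);
}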
137
csrc/quantization/fp8/amd/hip_float8.h
Normal file
@@ -0,0 +1,137 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_runtime.h>
|
||||
#else
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
#include "hip_float8_impl.h"
|
||||
|
||||
struct alignas(1) hip_fp8 {
|
||||
struct from_bits_t {};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
uint8_t data;
|
||||
|
||||
hip_fp8() = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||
: data(v) {}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// NOTE: ON-DEVICE... always optimal bias
|
||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||
      : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}

  explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
      : hip_fp8(static_cast<float>(v)) {}

  // Host only implementation using s/w simulation
  explicit HIP_FP8_HOST
#else  // __HIP__MI300__
  // both Host and DEVICE for non-MI300 using s/w simulation
  explicit HIP_FP8_HOST_DEVICE
#endif  // __HIP__MI300__
  hip_fp8(float v) {
    data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                   true /*clip*/>(v);
  }

  explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
      : hip_fp8(static_cast<float>(v)) {}

#ifdef __HIP__MI300__
  // upcast using device specific intrinsic
  explicit inline HIP_FP8_DEVICE operator float() const {
    float fval;
    uint32_t i32val = static_cast<uint32_t>(data);

    // upcast
    asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
                 : "=v"(fval)
                 : "v"(i32val));

    return fval;
  }

  explicit inline HIP_FP8_HOST operator float() const
#else  // __HIP__MI300__
  explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif  // __HIP__MI300__
  {
    return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
        data);
  }
};

namespace std {
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
}  // namespace std

// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
  return os << float(f8);
}

// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
  return (fa + float(b));
}

inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
  return (float(a) + fb);
}

inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
  return hip_fp8(float(a) + float(b));
}

inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
  return a = hip_fp8(float(a) + float(b));
}

// overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
  return float(a) * float(b);
}

inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
  return (a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
  return (float(a) * b);
}

inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
  return ((float)a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
  return ((float)a * float(b));
}

// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
  return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
  return (a.data != b.data);
}

inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) > static_cast<float>(b);
}
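A minimal host-side sketch (not part of the diff) of how the operators above behave; it assumes the software-simulation path (__HIPCC__ undefined) and that hip_float8.h is on the include path. Mixed fp8/float arithmetic is carried out in fp32 and returns float, while fp8-only addition re-quantizes to fp8.

#include <iostream>
#include "hip_float8.h"

int main() {
  hip_fp8 a(1.5f);       // float -> fp8 via to_float8<4, 3, ...>
  hip_fp8 b(0.25);       // double is routed through the float constructor
  float s = a + 2.0f;    // mixed types: computed in f32, returns float
  hip_fp8 c = a + b;     // fp8 + fp8: computed in f32, re-quantized to fp8
  bool same = (a == b);  // bitwise compare on the underlying uint8_t
  std::cout << a << " " << s << " " << c << " " << same << "\n";
  return 0;
}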
csrc/quantization/fp8/amd/hip_float8_impl.h (new file, 316 lines)
@@ -0,0 +1,316 @@
#pragma once
|
||||
|
||||
#if defined(__HIPCC__) && \
|
||||
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#define __HIP__MI300__
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
||||
#define HIP_FP8_HOST __host__
|
||||
#define HIP_FP8_DEVICE __device__
|
||||
#else
|
||||
#define HIP_FP8_HOST_DEVICE
|
||||
#define HIP_FP8_HOST
|
||||
#define HIP_FP8_DEVICE
|
||||
#endif
|
||||
|
||||
namespace hip_fp8_impl {
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
|
||||
uint8_t i8data;
|
||||
union {
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // NOTE: not endian independent
|
||||
} val;
|
||||
|
||||
uint32_t ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if ((val.i32val & 0x7F800000) !=
|
||||
0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
||||
false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0];
|
||||
|
||||
return i8data;
|
||||
}
|
||||
#endif // __HIP__MI300__
|
||||
|
||||
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||
#endif
|
||||
|
||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
|
||||
uint32_t rng = 0) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(wm + we == 7, "wm+we==7");
|
||||
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
||||
|
||||
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||
uint32_t x;
|
||||
if (sizeof(T) == 4) {
|
||||
x = reinterpret_cast<uint32_t&>(_x);
|
||||
} else {
|
||||
x = reinterpret_cast<uint16_t&>(_x);
|
||||
}
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
if (sizeof(T) == 4) {
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
} else {
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
||||
|
||||
// Deal with inf and NaNs
|
||||
if (negative_zero_nan) {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return 0x80;
|
||||
}
|
||||
} else {
|
||||
// if(__hisinf(x) || __hisnan(x))
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return 0x80;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
} else {
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First check whether the input is normal or denormal, since that changes the
// implicit 1. Then adjust the exponent to align with the F8 exponent, shifting
// the mantissa accordingly. For stochastic rounding, add rng to the mantissa
// and truncate; for RNE, no rng is added. Finally, check whether rounding
// produced a carry and, if so, adjust the exponent and mantissa again.
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent =
|
||||
1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, f8_exponent, exponent_diff;
|
||||
|
||||
if (exponent == 0) { // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shifted left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff =
|
||||
f8_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
} else { // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if (act_exponent <= f8_denormal_act_exponent) {
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
||||
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjusted to -6 and the mantissa shifted right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
} else { // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no
// difference for this case; act_exponent could be
// larger. It just does not require a mantissa shift.
|
||||
}
|
||||
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part
|
||||
and make something that is not a midpoint look like one. For example, the fp16
number 0x1002 (0 00100 0000000010) is larger than the midpoint, but after a
right shift by 4 bits it would look like a midpoint.
|
||||
*/
|
||||
|
||||
if (exponent_diff > 0) {
|
||||
mantissa >>= exponent_diff;
|
||||
} else if (exponent_diff == -1) {
|
||||
mantissa <<= -exponent_diff;
|
||||
}
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and needs to be
// adjusted to the denormal exponent
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
|
||||
f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
|
||||
// that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
|
||||
drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if (f8_exponent == 0) {
|
||||
if ((1 << mfmt) & mantissa) {
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
} else {
|
||||
if ((1 << (mfmt + 1)) & mantissa) {
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
||||
if (f8_exponent > max_exp) {
|
||||
if (clip) {
|
||||
mantissa = (1 << wm) - 1;
|
||||
f8_exponent = max_exp;
|
||||
} else {
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
if (f8_exponent == 0 && mantissa == 0) {
|
||||
return negative_zero_nan ? 0 : (sign << 7);
|
||||
}
|
||||
mantissa &= (1 << wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||
}
|
||||
|
||||
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : 8;
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
|
||||
#ifdef __HIPCC__
|
||||
if (is_half) {
|
||||
const uint16_t ihInf = 0x7C00;
|
||||
const uint16_t ihNegInf = 0xFC00;
|
||||
const uint16_t ihNaN = 0x7C01;
|
||||
const uint16_t ihNeg0 = 0x8000;
|
||||
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
||||
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
||||
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
||||
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
||||
} else
|
||||
#endif
|
||||
if (is_float) {
|
||||
const uint32_t ifInf = 0x7F800000;
|
||||
const uint32_t ifNegInf = 0xFF800000;
|
||||
const uint32_t ifNaN = 0x7F800001;
|
||||
const uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = reinterpret_cast<const float&>(ifInf);
|
||||
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
||||
fNaN = reinterpret_cast<const float&>(ifNaN);
|
||||
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
||||
}
|
||||
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t sign = x >> 7;
|
||||
uint32_t mantissa = x & ((1 << wm) - 1);
|
||||
int exponent = (x & 0x7F) >> wm;
|
||||
if (negative_zero_nan) {
|
||||
if (x == 0x80) {
|
||||
return fNaN;
|
||||
}
|
||||
} else {
|
||||
if (x == 0x80) {
|
||||
return fNeg0;
|
||||
}
|
||||
if (exponent == ((1 << we) - 1)) {
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
}
|
||||
}
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||
if (we == 5 && is_half && !negative_zero_nan) {
|
||||
retval = x << 8;
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff =
|
||||
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if (exponent == 0) {
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + clz(mantissa) - (32 - wm);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << wm) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - wm;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if (exponent <= 0) {
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if (sizeof(T) == 2) {
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||
} else {
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||
}
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
} // namespace hip_fp8_impl
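A small host-side round trip through the software conversion above (not part of the diff). It uses the same template parameters that hip_fp8 uses (we=4, wm=3, negative_zero_nan=true, clip=true) and assumes hip_float8.h is on the include path so the host includes (<type_traits>, <stdint.h>) are pulled in.

#include <cassert>
#include <cstdint>
#include "hip_float8.h"

int main() {
  float in = 1.0f;  // exactly representable in e4m3, so the round trip is lossless
  uint8_t f8 = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                       true /*clip*/>(in);
  float out = hip_fp8_impl::from_float8<4, 3, float, true>(f8);
  assert(out == in);
  return 0;
}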
csrc/quantization/fp8/amd/quant_utils.cuh (new file, 575 lines)
@@ -0,0 +1,575 @@
|
||||
#pragma once
|
||||
#include "hip_float8.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../attention/dtype_fp8.cuh"
|
||||
#include "../../../attention/dtype_float32.cuh"
|
||||
#include "../../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
namespace vllm {
|
||||
#ifdef USE_ROCM
|
||||
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
||||
const float scale) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0];
|
||||
tmp.h2r.y.data = f2[1];
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
||||
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f);
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0];
|
||||
res.y = f2[1];
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
||||
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_
|
||||
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)};
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
||||
hip_fp8 res{__bfloat162float(a)};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
||||
hip_fp8 f8(a);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, float2>(const float2& a) {
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
// Float4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
// float2 -> bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
|
||||
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> bfloat162x2
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
|
||||
bf16_4_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> bfloat162x4
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
||||
bf16_8_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
b.z = __float22bfloat162_rn(a.z);
|
||||
b.w = __float22bfloat162_rn(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between the high and
   low precision domains.

   Convention of the scale in the API, e.g.:
     FP8_data = Quantization(High_Precision_data / scale)
   s.t.
     Quantize(HP / scale) => FP8
     Dequant(FP8) * scale => HP
*/
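// Illustrative note (not part of the diff): with scale = 0.5f and a
// high-precision value of 1.0f, quantization stores Quantize(1.0f / 0.5f) =
// Quantize(2.0f) as one fp8 byte, and dequantization returns
// from_float8(...) * 0.5f, which recovers 1.0f exactly here since 2.0f is
// representable in e4m3.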
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8) * scale;
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0] * scale;
|
||||
tmp.h2r.y.data = f2[1] * scale;
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] =
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
|
||||
static_cast<uint8_t>(a >> 8U), scale);
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4
|
||||
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
|
||||
const float scale) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f * scale);
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
const float scale) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y =
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t& a, const float scale) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0] * scale;
|
||||
res.y = f2[1] * scale;
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
|
||||
scale);
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_
|
||||
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Quantize(HP / scale) => FP8 */
|
||||
|
||||
// TODO(Hai): add vectorized versions
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data) / scale};
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, const float scale) {
|
||||
hip_fp8 res{__bfloat162float(a) / scale};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
|
||||
hip_fp8 f8(a / scale);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout convert(const Tin& x) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return vec_conversion<Tout, Tin>(x);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
}
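// Illustrative only (not part of the diff): dequantizing one fp8 byte from the
// KV cache into a half (uint16_t bit pattern) inside a kernel would look like
//   uint8_t cached = ...;  // one e4m3 byte read from the cache
//   uint16_t h = vllm::fp8::scaled_convert<uint16_t, uint8_t,
//                    vllm::Fp8KVCacheDataType::kFp8E4M3>(cached, scale);
// kv_dt is known at compile time, so the constexpr branch above folds away.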
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. FN is a macro that calls a
// function templated on <typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
|
||||
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||
if (KV_DTYPE == "auto") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
}
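// Illustrative only (not part of the diff): a hypothetical call site for the
// dispatch macro above. CALL_RESHAPE and reshape_and_cache_launcher are
// made-up names; the real callers are the cache/attention kernels elsewhere
// in csrc.
//
//   #define CALL_RESHAPE(SCALAR_T, CACHE_T, KV_DT)                     \
//     reshape_and_cache_launcher<SCALAR_T, CACHE_T, KV_DT>(            \
//         key, value, key_cache, value_cache);
//
//   DISPATCH_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype, CALL_RESHAPE);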
|
||||
|
||||
} // namespace fp8
|
||||
#endif // USE_ROCM
|
||||
} // namespace vllm
|
||||
csrc/quantization/fp8/amd/hip_float8.h (deleted, 167 lines)
@@ -1,167 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_runtime.h>
|
||||
#else
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
#include "hip_float8_impl.h"
|
||||
|
||||
struct alignas(1) hip_fp8
|
||||
{
|
||||
struct from_bits_t
|
||||
{
|
||||
};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
|
||||
uint8_t data;
|
||||
|
||||
hip_fp8() = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||
: data(v)
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// NOTE: ON-DEVICE... always optimal bias
|
||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||
: data(hip_fp8_impl::to_fp8_from_fp32(v))
|
||||
{
|
||||
}
|
||||
|
||||
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
|
||||
// Host only implementation using s/w simulation
|
||||
explicit HIP_FP8_HOST
|
||||
#else // __HIP__MI300__
|
||||
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||
explicit HIP_FP8_HOST_DEVICE
|
||||
#endif // __HIP__MI300__
|
||||
hip_fp8(float v)
|
||||
{
|
||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
|
||||
}
|
||||
|
||||
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// upcast using device specific intrinsic
|
||||
explicit inline HIP_FP8_DEVICE operator float() const
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(data);
|
||||
|
||||
// upcast
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
|
||||
return fval;
|
||||
}
|
||||
|
||||
explicit inline HIP_FP8_HOST operator float() const
|
||||
#else // __HIP__MI300__
|
||||
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||
#endif // __HIP__MI300__
|
||||
{
|
||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
|
||||
}
|
||||
};
|
||||
|
||||
namespace std
|
||||
{
|
||||
inline hip_fp8 sin(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(sinf(float(a)));
|
||||
}
|
||||
inline hip_fp8 cos(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(cosf(float(a)));
|
||||
}
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
|
||||
{
|
||||
return a;
|
||||
}
|
||||
} // namespace std
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
|
||||
{
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
// all + operator overloading with mixed types
|
||||
// mixed types, always converts to f32, does computation in f32, and returns float
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
|
||||
{
|
||||
return (fa + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
|
||||
{
|
||||
return (float(a) + fb);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
|
||||
{
|
||||
return a = hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
// overloading multiplication, always returns float,
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
|
||||
{
|
||||
return (a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
|
||||
{
|
||||
return (float(a) * b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
|
||||
{
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
|
||||
{
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
// overloading for compare
|
||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return (a.data == b.data);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return (a.data != b.data);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return static_cast<float>(a) >= static_cast<float>(b);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return static_cast<float>(a) > static_cast<float>(b);
|
||||
}
|
||||
csrc/quantization/fp8/amd/hip_float8_impl.h (deleted, 316 lines)
@@ -1,316 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#define __HIP__MI300__
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
||||
#define HIP_FP8_HOST __host__
|
||||
#define HIP_FP8_DEVICE __device__
|
||||
#else
|
||||
#define HIP_FP8_HOST_DEVICE
|
||||
#define HIP_FP8_HOST
|
||||
#define HIP_FP8_DEVICE
|
||||
#endif
|
||||
|
||||
namespace hip_fp8_impl
|
||||
{
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
||||
{
|
||||
uint8_t i8data;
|
||||
union {
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // NOTE: not endian independent
|
||||
} val;
|
||||
|
||||
uint32_t ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
||||
false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0];
|
||||
|
||||
return i8data;
|
||||
}
|
||||
#endif // __HIP__MI300__
|
||||
|
||||
HIP_FP8_HOST inline int clz(uint32_t x)
|
||||
{
|
||||
return __builtin_clz(x);
|
||||
}
|
||||
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x)
|
||||
{
|
||||
return __clz(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
|
||||
{
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(wm + we == 7, "wm+we==7");
|
||||
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
||||
|
||||
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||
uint32_t x;
|
||||
if (sizeof(T) == 4) {
|
||||
x = reinterpret_cast<uint32_t&>(_x);
|
||||
} else {
|
||||
x = reinterpret_cast<uint16_t&>(_x);
|
||||
}
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
if (sizeof(T) == 4) {
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
} else {
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
||||
|
||||
// Deal with inf and NaNs
|
||||
if (negative_zero_nan) {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return 0x80;
|
||||
}
|
||||
} else {
|
||||
// if(__hisinf(x) || __hisnan(x))
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return 0x80;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
} else {
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
||||
// need to check whether there is carry and adjust exponent and mantissa again
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, f8_exponent, exponent_diff;
|
||||
|
||||
if (exponent == 0) { // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
} else { // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if (act_exponent <= f8_denormal_act_exponent) {
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
||||
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
} else { // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
|
||||
// for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part
|
||||
and make something not midpoint look like midpoint. For example, the fp16
|
||||
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
||||
shift right by 4 bits, it would look like midpoint.
|
||||
*/
|
||||
|
||||
if (exponent_diff > 0) {
|
||||
mantissa >>= exponent_diff;
|
||||
} else if (exponent_diff == -1) {
|
||||
mantissa <<= -exponent_diff;
|
||||
}
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
|
||||
// is not truncated is 1
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if (f8_exponent == 0) {
|
||||
if ((1 << mfmt) & mantissa) {
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
} else {
|
||||
if ((1 << (mfmt + 1)) & mantissa) {
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
||||
if (f8_exponent > max_exp) {
|
||||
if (clip) {
|
||||
mantissa = (1 << wm) - 1;
|
||||
f8_exponent = max_exp;
|
||||
} else {
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
if (f8_exponent == 0 && mantissa == 0) {
|
||||
return negative_zero_nan ? 0 : (sign << 7);
|
||||
}
|
||||
mantissa &= (1 << wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||
}
|
||||
|
||||
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
|
||||
{
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : 8;
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
|
||||
#ifdef __HIPCC__
|
||||
if (is_half) {
|
||||
const uint16_t ihInf = 0x7C00;
|
||||
const uint16_t ihNegInf = 0xFC00;
|
||||
const uint16_t ihNaN = 0x7C01;
|
||||
const uint16_t ihNeg0 = 0x8000;
|
||||
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
||||
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
||||
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
||||
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
||||
} else
|
||||
#endif
|
||||
if (is_float) {
|
||||
const uint32_t ifInf = 0x7F800000;
|
||||
const uint32_t ifNegInf = 0xFF800000;
|
||||
const uint32_t ifNaN = 0x7F800001;
|
||||
const uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = reinterpret_cast<const float&>(ifInf);
|
||||
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
||||
fNaN = reinterpret_cast<const float&>(ifNaN);
|
||||
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
||||
}
|
||||
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t sign = x >> 7;
|
||||
uint32_t mantissa = x & ((1 << wm) - 1);
|
||||
int exponent = (x & 0x7F) >> wm;
|
||||
if (negative_zero_nan) {
|
||||
if (x == 0x80) {
|
||||
return fNaN;
|
||||
}
|
||||
} else {
|
||||
if (x == 0x80) {
|
||||
return fNeg0;
|
||||
}
|
||||
if (exponent == ((1 << we) - 1)) {
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
}
|
||||
}
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||
if (we == 5 && is_half && !negative_zero_nan) {
|
||||
retval = x << 8;
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if (exponent == 0) {
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + clz(mantissa) - (32 - wm);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << wm) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - wm;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if (exponent <= 0) {
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if (sizeof(T) == 2) {
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||
} else {
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||
}
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
} // namespace hip_fp8_impl
|
||||
csrc/quantization/fp8/amd/quant_utils.cuh (deleted, 517 lines)
@@ -1,517 +0,0 @@
|
||||
#pragma once
|
||||
#include "hip_float8.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../attention/dtype_float32.cuh"
|
||||
#include "../../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
namespace vllm
|
||||
{
|
||||
namespace fp8_e4m3 {
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0];
|
||||
tmp.h2r.y.data = f2[1];
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
||||
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f);
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0];
|
||||
res.y = f2[1];
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
||||
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)};
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||
{
|
||||
hip_fp8 res{__bfloat162float(a)};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
hip_fp8 f8(a);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
// Float4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
// float2 -> bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
|
||||
{
|
||||
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> bfloat162x2
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
|
||||
{
|
||||
bf16_4_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> bfloat162x4
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
|
||||
{
|
||||
bf16_8_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
b.z = __float22bfloat162_rn(a.z);
|
||||
b.w = __float22bfloat162_rn(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low precision domains
|
||||
|
||||
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
|
||||
s.t.
|
||||
Quantize(HP / scale) => FP8
|
||||
Dequant(FP8) * scale => HP
|
||||
|
||||
*/
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8) * scale;
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0] * scale;
|
||||
tmp.h2r.y.data = f2[1] * scale;
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f * scale);
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0] * scale;
|
||||
res.y = f2[1] * scale;
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
/* Quantize(HP / scale) => FP8 */
|
||||
|
||||
// TODO(Hai): add vectorized versions of these quantize paths
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)/scale};
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
|
||||
{
|
||||
hip_fp8 res{__bfloat162float(a)/scale};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
|
||||
{
|
||||
hip_fp8 f8(a/scale);
|
||||
return f8.data;
|
||||
}
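
// Illustrative round-trip sketch (not part of the original file): shows the
// scale convention documented above, using the scalar specializations defined
// in this header. The function name is ours; the result only approximately
// equals `hp` because of fp8 rounding and saturation.
__inline__ __device__ float scaled_fp8_round_trip_example(const float hp, const float scale)
{
  uint8_t f8 = scaled_vec_conversion<uint8_t, float>(hp, scale);  // Quantize(HP / scale)
  return scaled_vec_conversion<float, uint8_t>(f8, scale);        // Dequant(FP8) * scale
}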
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace vllm
|
||||
csrc/quantization/fp8/common.cu (new file, 124 lines)
@@ -0,0 +1,124 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

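// Note: for non-negative floats the IEEE-754 bit pattern, viewed as a signed
// int, is monotonic in the float value, so atomicMax on the int bits computes
// a float max; for negative values the unsigned atomicMin on the bits plays
// the same role. The callers below only ever pass non-negative values.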
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
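// (FP8_E4M3_MAX evaluates to 448.0f for the e4m3fn format used here.)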

template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    const scalar_t val, const float scale) {
  float x = static_cast<float>(val) / scale;
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<c10::Float8_e4m3fn>(r);
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
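  // (This tree reduction assumes blockDim.x is a power of two; the host
  //  launchers below always use a block size of 1024.)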
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = scaled_fp8_conversion(input[i], *scale);
    i += blockDim.x * gridDim.x;
  }
}

} // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                             torch::Tensor& input,  // [..., d]
                             torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                              torch::Tensor& input,  // [..., d]
                              torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}
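
// Hedged usage sketch (ours, not part of the diff): drives the dynamic path.
// Per the comment on segmented_max_reduction, the scale tensor must start at
// a value <= 0 (zeros work) because it is only ever raised via atomicMaxFloat.
// The function name and tensors below are illustrative, not vLLM API.
void example_dynamic_scaled_fp8_quant(torch::Tensor& my_input) {
  auto out = torch::empty_like(
      my_input, my_input.options().dtype(at::ScalarType::Float8_e4m3fn));
  auto scale = torch::zeros({1}, my_input.options().dtype(at::kFloat));
  dynamic_scaled_fp8_quant(out, my_input, scale);
  // Dequantize later as: out.to(at::kFloat) * scale
}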
@@ -1,126 +0,0 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
float old;
|
||||
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
|
||||
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
|
||||
return old;
|
||||
}
|
||||
|
||||
// Compute the absolute maximum m of the input tensor and store
|
||||
// m / float8_e4m3::max() in *scale. Each thread block performs a
|
||||
// reduction tree and the memory in scale is atomically updated.
|
||||
// So to get the right answer, *scale needs to be initialized to
|
||||
// a value <= 0.0 and we need to wait for all thread blocks to
|
||||
// finish before consuming *scale.
|
||||
template<typename scalar_t>
|
||||
__global__ void segmented_max_reduction(
|
||||
float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
__shared__ float cache[1024];
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// First store maximum for all values processes by
|
||||
// the current thread in cache[threadIdx.x]
|
||||
scalar_t tmp = 0.0;
|
||||
while (i < num_elems) {
|
||||
float x = static_cast<float>(input[i]);
|
||||
tmp = max(tmp, fabs(x));
|
||||
i += blockDim.x * gridDim.x;
|
||||
}
|
||||
cache[threadIdx.x] = tmp;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Now perform parallel reduction within the thread block
|
||||
int ib = blockDim.x / 2;
|
||||
while (ib != 0) {
|
||||
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
|
||||
cache[threadIdx.x] = cache[threadIdx.x + ib];
|
||||
}
|
||||
__syncthreads();
|
||||
ib /= 2;
|
||||
}
|
||||
// Finally, since cache[0] contains the maximum for this thread block,
|
||||
// atomically write the max to the target location
|
||||
if (threadIdx.x == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(
|
||||
c10::Float8_e4m3fn* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
while (i < num_elems) {
|
||||
out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
|
||||
i += blockDim.x * gridDim.x;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(1024);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"scaled_fp8_quant_kernel",
|
||||
[&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(),
|
||||
num_elems);
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(1024);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"scaled_fp8_quant_kernel",
|
||||
[&] {
|
||||
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(),
|
||||
num_elems);
|
||||
});
|
||||
}
|
||||
|
||||
csrc/quantization/fp8/nvidia/quant_utils.cuh (new file, 570 lines)
@@ -0,0 +1,570 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../../attention/attention_dtypes.h"
|
||||
#include <assert.h>
|
||||
#include <float.h>
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace vllm {
|
||||
#ifndef USE_ROCM
|
||||
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout
|
||||
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
||||
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
|
||||
tmp.u32[1] =
|
||||
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
|
||||
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
|
||||
res.y =
|
||||
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float
|
||||
vec_conversion<float, uint8_t>(const uint8_t &a,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
|
||||
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
|
||||
const float &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
|
||||
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
|
||||
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
|
||||
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
|
||||
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between the high- and
   low-precision domains.

   Scale convention of this API: FP8_data = Quantize(High_Precision_data / scale), i.e.
     Quantize(HP / scale)  => FP8
     Dequant(FP8) * scale  => HP
*/
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(
|
||||
const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
return float_to_half(half_to_float(tmp.x) * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||
tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
|
||||
tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale, fp8_type);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4
|
||||
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
|
||||
fp8_type);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
|
||||
scale, fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
|
||||
fp8_type);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale, fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
|
||||
const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
uint16_t tmp = res.x;
|
||||
|
||||
// half -> float
|
||||
return half_to_float(tmp) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
|
||||
fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
|
||||
const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||
__NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
|
||||
const float& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout convert(const Tin& x) {
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
|
||||
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
|
||||
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
}
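
// Example (ours, not from the diff): dequantize one fp8 byte from an E4M3
// KV cache into a packed half, applying the cache scale:
//   uint16_t h = scaled_convert<uint16_t, uint8_t,
//                               Fp8KVCacheDataType::kFp8E4M3>(fp8_byte, scale);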
|
||||
|
||||
// The following macro dispatches to FN based on the data type of the key and
// value cache. FN must itself be a macro that invokes a function templated on
// <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>; a usage
// sketch follows the macro definition.
|
||||
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||
if (KV_DTYPE == "auto") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_e5m2") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
}
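
// Hedged usage sketch (the kernel and pointer names below are hypothetical,
// not taken from this diff). FN receives the resolved
// (scalar_t, cache_t, kv_dt) triple:
//
//   #define CALL_MY_COPY(KV_T, CACHE_T, KV_DTYPE)                          \
//     my_copy_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
//         key_ptr, cache_ptr, scale, num_elems);
//
//   DISPATCH_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype, CALL_MY_COPY)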
|
||||
|
||||
} // namespace fp8
|
||||
#endif // not USE_ROCM
|
||||
} // namespace vllm
|
||||
@@ -1,277 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
#include "../../attention/attention_dtypes.h"
|
||||
#include "../../attention/dtype_float32.cuh"
|
||||
#include "../../attention/dtype_float16.cuh"
|
||||
#include "../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
namespace fp8_e5m2_unscaled {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template<>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template<>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template<>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template<>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// half -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
|
||||
__nv_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
} // namespace fp8_e5m2_unscaled
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
} // namespace vllm
|
||||
@@ -9,54 +9,54 @@ namespace vllm {
|
||||
namespace gptq {
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
|
||||
unsigned int* address_as_ui =
|
||||
(unsigned int*)((char*)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
do {
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
|
||||
: (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) {
|
||||
atomicAdd_half(address, val);
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
|
||||
atomicAdd_half2(address, val);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace gptq
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
|
||||
Adapted from https://github.com/turboderp/exllamav2 and
|
||||
https://github.com/turboderp/exllama
|
||||
*/
|
||||
|
||||
#ifndef _matrix_view_cuh
|
||||
@@ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
class MatrixView_half {
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
|
||||
return __half2half2(data[row * width + column]);
|
||||
}
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row,
|
||||
int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row,
|
||||
int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row,
|
||||
int column) const {
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
class MatrixView_half_rw {
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
__device__ __forceinline__ half item(int row, int column) const {
|
||||
return data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const {
|
||||
return ((half2*)data)[(row * width + column) / 2];
|
||||
}
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) {
|
||||
return &data[row * width + column];
|
||||
}
|
||||
__device__ __forceinline__ void set(int row, int column, half value) {
|
||||
data[row * width + column] = value;
|
||||
}
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
|
||||
((half2*)data)[(row * width + column) / 2] = value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||
{
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1,
|
||||
half v2, half v3) {
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
class MatrixView_q4_row {
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data,
|
||||
const int height,
|
||||
const int width)
|
||||
: data(data), height(height), width(width) {}
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
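  // (e.g. column = 5 -> shift = 20, so bits 20..23 of the packed 32-bit word
  //  hold the 4-bit value that item() returns)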
__device__ __forceinline__ int item(int row, int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row,
|
||||
int column) const {
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data,
                                                  const int height,
                                                  const int width)
      : data(data), height(height), width(width) {}

  __device__ __forceinline__ int item(int row, int column) const {
    int shift = (row & 0x07) * 4;
    return (data[row / 8 * width + column] >> shift) & 0x0f;
  }

  __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
    return data[row / 8 * width + column];
  }
  __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row,
                                                             int column) {
    return &data[row / 8 * width + column];
  }
};

class MatrixView_q2_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data,
                                               const int height,
                                               const int width)
      : data(data), height(height), width(width) {}

  __device__ __forceinline__ int item(int row, int column) const {
    int shift = (column & 0x0f) * 2;
    return (data[row * width / 16 + column / 16] >> shift) & 0x03;
  }

  __device__ __forceinline__ void item2(int (&items)[2], int row,
                                        int column) const {
    int shift = (column & 0x0f) * 2;
    uint32_t d = data[row * width / 16 + column / 16] >> shift;
    items[0] = d & 0x03;
    items[1] = (d >> 2) & 0x03;
  }

  __device__ __forceinline__ void item4(int (&items)[4], int row,
                                        int column) const {
    int shift = (column & 0x0f) * 2;
    uint32_t d = data[row * width / 16 + column / 16] >> shift;
    items[0] = d & 0x03;
    items[1] = (d >> 2) & 0x03;
    items[2] = (d >> 4) & 0x03;
    items[3] = (d >> 6) & 0x03;
  }
};

class MatrixView_q3_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data,
                                               const int height,
                                               const int width)
      : data(data), height(height), width(width) {}

  __device__ __forceinline__ int item(int row, int column) const {
    int z_w = column * 3 / 32;
    int z_mod = column & 0x1f;

    if (z_mod == 10) {
      return (data[row * width * 3 / 32 + z_w] >> 30) |
             ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
    } else if (z_mod == 21) {
      return (data[row * width * 3 / 32 + z_w] >> 31) |
             ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
    } else if (z_mod < 10) {
      return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
    } else if (z_mod < 21) {
      return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
    } else {
      return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
    }
  }

  __device__ __forceinline__ void item4(int (&items)[4], int row,
                                        int column) const {
    int shift = (column & 0x1f);
    uint32_t d;
    if (shift <= 4) {
      d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
    } else if (shift == 8) {
      d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
          ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
    } else if (shift <= 16) {
      d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
    } else if (shift == 20) {
      d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
          ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
    } else {
      d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
    }
    items[0] = d & 0x07;
    items[1] = (d >> 3) & 0x07;
    items[2] = (d >> 6) & 0x07;
    items[3] = (d >> 9) & 0x07;
  }
};

class MatrixView_q8_row {
 public:
  const uint32_t* data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data,
                                               const int height,
                                               const int width)
      : data(data), height(height), width(width) {}

  __device__ __forceinline__ int item(int row, int column) const {
    int shift = (column & 0x03) * 8;
    return (data[row * width / 4 + column / 4] >> shift) & 0xff;
  }

  __device__ __forceinline__ void item2(int (&items)[2], int row,
                                        int column) const {
    int shift = (column & 0x03) * 8;
    uint32_t d = data[row * width / 4 + column / 4] >> shift;
    items[0] = d & 0xff;
    items[1] = (d >> 8) & 0xff;
  }

  __device__ __forceinline__ void item4(int (&items)[4], int row,
                                        int column) const {
    int shift = (column & 0x03) * 2;
    uint32_t d = data[row * width / 4 + column / 4] >> shift;
    items[0] = d & 0xff;
    items[1] = (d >> 8) & 0xff;
    items[2] = (d >> 16) & 0xff;
    items[3] = (d >> 24) & 0xff;
  }
};

} // namespace gptq
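As context for the views above: each *_row view packs several quantized values into one uint32_t along a row, and item()/item2()/item4() only differ in how many adjacent fields they unpack from a single load. A small host-side sketch of the 4-bit case follows (illustrative only, not part of this diff; q4_row_item is a hypothetical helper that mirrors MatrixView_q4_row::item):

#include <cassert>
#include <cstdint>

// Host-side mirror of MatrixView_q4_row::item(): eight 4-bit values per
// uint32_t word, packed least-significant nibble first along the row.
static int q4_row_item(const uint32_t* data, int width, int row, int column) {
  int shift = (column & 0x07) * 4;  // nibble position inside the word
  return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}

int main() {
  // One row of width 8 holding the values 0..7, packed into a single word.
  uint32_t packed = 0;
  for (int c = 0; c < 8; c++) packed |= (uint32_t)c << (c * 4);
  for (int c = 0; c < 8; c++) assert(q4_row_item(&packed, 8, 0, c) == c);
  return 0;
}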
File diff suppressed because it is too large

@@ -14,71 +14,60 @@ namespace gptq {
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
  uint32_t qa = q[0];
  uint32_t qb = 0;

#pragma unroll
  for (int i = 0; i < 8; i++) {
    uint32_t qa0 = qa & 0x03;
    uint32_t qa1 = (qa & 0x0c) >> 2;
    qa >>= 4;
    qb |= (qa1 << (i * 2 + 16));
    qb |= (qa0 << (i * 2));
  }
  q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0,
                                                half2 (&dq)[8], int stride,
                                                const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y4_ = __float2half_rn(1.0f / 4.0f);
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half y64_ = __float2half_rn(1.0f / 64.0f);
  const half2 y4 = __halves2half2(y4_, y4_);
  const half2 y16 = __halves2half2(y16_, y16_);
  const half2 y64 = __halves2half2(y64_, y64_);

  const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero);
  const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
  const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
  const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
  const half2 z1 = __half2half2(z1_.as_half);
  const half2 z4 = __half2half2(z4_);
  const half2 z16 = __half2half2(z16_);
  const half2 z64 = __half2half2(z64_);

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x00030003) | c0);  // half2(q[ 0], q[ 1]) + 1024
  half2_uint32 q1((qa & 0x000c000c) | c0);  // half2(q[ 2], q[ 3]) * 4 + 1024
  half2_uint32 q2((qa & 0x00300030) | c0);  // half2(q[ 4], q[ 5]) * 16 + 1024
  half2_uint32 q3((qa & 0x00c000c0) | c0);  // half2(q[ 6], q[ 7]) * 64 + 1024
  qa >>= 8;
  half2_uint32 q4((qa & 0x00030003) | c0);  // half2(q[ 8], q[ 9]) + 1024
  half2_uint32 q5((qa & 0x000c000c) | c0);  // half2(q[10], q[11]) * 4 + 1024
  half2_uint32 q6((qa & 0x00300030) | c0);  // half2(q[12], q[13]) * 16 + 1024
  half2_uint32 q7((qa & 0x00c000c0) | c0);  // half2(q[14], q[15]) * 64 + 1024

  dq[0] = __hadd2(q0.as_half2, z1);
  dq[1] = __hfma2(q1.as_half2, y4, z4);
  dq[2] = __hfma2(q2.as_half2, y16, z16);
  dq[3] = __hfma2(q3.as_half2, y64, z64);
  dq[4] = __hadd2(q4.as_half2, z1);
  dq[5] = __hfma2(q5.as_half2, y4, z4);
  dq[6] = __hfma2(q6.as_half2, y16, z16);
  dq[7] = __hfma2(q7.as_half2, y64, z64);
}

} // namespace gptq
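The constant c0 = 0x64006400 packs two copies of 0x6400, the fp16 encoding of 1024.0. Because the mantissa step of 1024.0 is exactly 1.0, OR-ing a small quantized value into the low bits produces 1024 + q with no integer-to-float conversion, and the z1/z4/z16/z64 offsets then subtract the matching (1024 + zero) term. A minimal host-side sketch of the same trick using fp32 bit patterns follows (illustrative analogy only, not code from this diff; the kernels do this per half2 lane):

#include <cassert>
#include <cstdint>
#include <cstring>

// Float analog of the kernel's half trick: OR a small integer into the
// mantissa of 2^23 (bit pattern 0x4B000000) and the float reads back as
// 2^23 + q, so a single subtraction yields q - zero.
static float dequant_via_bits(uint32_t q, uint32_t zero) {
  const uint32_t c0 = 0x4B000000u;  // bit pattern of 8388608.0f (2^23)
  uint32_t bits = c0 | q;           // valid for q < 2^23
  float biased;
  std::memcpy(&biased, &bits, sizeof(biased));
  return biased - (8388608.0f + (float)zero);  // == (float)q - zero
}

int main() {
  for (uint32_t q = 0; q < 16; q++)
    assert(dequant_via_bits(q, 8) == (float)q - 8.0f);
  return 0;
}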
@@ -11,128 +11,136 @@ namespace gptq {
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk

__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
  uint32_t qa = q[0 * stride];
  uint32_t qb = q[1 * stride];
  uint32_t qc = q[2 * stride];

  // qa: aa999888 77766655 54443332 22111000
  // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
  // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll

  uint32_t qd = qc >> 26;
  qc <<= 4;
  qc |= qb >> 28;
  qb <<= 2;
  qb |= qa >> 30;

  // qa: ..999888 77766655 54443332 22111000
  // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
  // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
  // qd: vvvuuu

  uint32_t za = 0;
  uint32_t zb = 0;
  uint32_t zc = 0;

  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qa & 0x07;
    uint32_t t1 = (qa & 0x38) >> 3;
    qa >>= 6;
    za |= (t0 << (i * 3));
    za |= (t1 << (i * 3 + 16));
  }
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qb & 0x07;
    uint32_t t1 = (qb & 0x38) >> 3;
    qb >>= 6;
    zb |= (t0 << (i * 3));
    zb |= (t1 << (i * 3 + 16));
  }
  for (int i = 0; i < 5; i++) {
    uint32_t t0 = qc & 0x07;
    uint32_t t1 = (qc & 0x38) >> 3;
    qc >>= 6;
    zc |= (t0 << (i * 3));
    zc |= (t1 << (i * 3 + 16));
  }

  // za: 9997775 55333111 8886664 44222000
  // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
  // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
  // qd: vvvuuu

  za |= ((qd & 0x01) >> 0) << 15;
  zb |= ((qd & 0x02) >> 1) << 15;
  zc |= ((qd & 0x04) >> 2) << 15;
  za |= ((qd & 0x08) >> 3) << 31;
  zb |= ((qd & 0x10) >> 4) << 31;
  zc |= ((qd & 0x20) >> 5) << 31;

  // za: v9997775 55333111 u8886664 44222000 (u, v lsb)
  // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
  // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk

  q[0 * stride] = za;
  q[1 * stride] = zb;
  q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0,
                                                const uint32_t q_1,
                                                const uint32_t q_2,
                                                half2 (&dq)[16], int stride,
                                                const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y8_ = __float2half_rn(1.0f / 8.0f);
  const half y64_ = __float2half_rn(1.0f / 64.0f);
  const half2 y8 = __halves2half2(y8_, y8_);
  const half2 y64 = __halves2half2(y64_, y64_);
  const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero);
  const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
  const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
  const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
  const half2 z8 = __halves2half2(z8_, z8_);
  const half2 z64 = __halves2half2(z64_, z64_);

  uint32_t qa = q_0;
  uint32_t qb = q_1;
  uint32_t qc = q_2;

  half2_uint32 q0((qa & 0x00070007) | c0);   // half2(q[ 0], q[ 1]) + 1024
  half2_uint32 q1((qa & 0x00380038) | c0);   // half2(q[ 2], q[ 3]) * 8 + 1024
  qa >>= 6;
  half2_uint32 q2((qa & 0x00070007) | c0);   // half2(q[ 4], q[ 5]) + 1024
  half2_uint32 q3((qa & 0x00380038) | c0);   // half2(q[ 6], q[ 7]) * 8 + 1024
  half2_uint32 q4((qa & 0x01c001c0) | c0);   // half2(q[ 8], q[ 9]) * 64 + 1024
  qa >>= 9;
  qa &= 0x00010001;
  half2_uint32 q5((qb & 0x00070007) | c0);   // half2(q[10], q[11]) + 1024
  half2_uint32 q6((qb & 0x00380038) | c0);   // half2(q[12], q[13]) * 8 + 1024
  qb >>= 6;
  half2_uint32 q7((qb & 0x00070007) | c0);   // half2(q[14], q[15]) + 1024
  half2_uint32 q8((qb & 0x00380038) | c0);   // half2(q[16], q[17]) * 8 + 1024
  half2_uint32 q9((qb & 0x01c001c0) | c0);   // half2(q[18], q[19]) * 64 + 1024
  qb >>= 8;
  qb &= 0x00020002;
  half2_uint32 q10((qc & 0x00070007) | c0);  // half2(q[20], q[21]) + 1024
  half2_uint32 q11((qc & 0x00380038) | c0);  // half2(q[22], q[23]) * 8 + 1024
  qc >>= 6;
  half2_uint32 q12((qc & 0x00070007) | c0);  // half2(q[24], q[25]) + 1024
  half2_uint32 q13((qc & 0x00380038) | c0);  // half2(q[26], q[27]) * 8 + 1024
  half2_uint32 q14((qc & 0x01c001c0) | c0);  // half2(q[28], q[29]) * 64 + 1024
  qc >>= 7;
  qc &= 0x00040004;
  half2_uint32 q15((qa | qb | qc) | c0);

  dq[0] = __hadd2(q0.as_half2, z1);
  dq[1] = __hfma2(q1.as_half2, y8, z8);
  dq[2] = __hadd2(q2.as_half2, z1);
  dq[3] = __hfma2(q3.as_half2, y8, z8);
  dq[4] = __hfma2(q4.as_half2, y64, z64);
  dq[5] = __hadd2(q5.as_half2, z1);
  dq[6] = __hfma2(q6.as_half2, y8, z8);
  dq[7] = __hadd2(q7.as_half2, z1);
  dq[8] = __hfma2(q8.as_half2, y8, z8);
  dq[9] = __hfma2(q9.as_half2, y64, z64);
  dq[10] = __hadd2(q10.as_half2, z1);
  dq[11] = __hfma2(q11.as_half2, y8, z8);
  dq[12] = __hadd2(q12.as_half2, z1);
  dq[13] = __hfma2(q13.as_half2, y8, z8);
  dq[14] = __hfma2(q14.as_half2, y64, z64);
  dq[15] = __hadd2(q15.as_half2, z1);
}

} // namespace gptq
@@ -13,133 +13,112 @@ namespace gptq {
//
// 77775555 33331111 66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
  uint32_t qa = q[0];
  uint32_t qb = 0;

#pragma unroll
  for (int i = 0; i < 4; i++) {
    uint32_t qa0 = qa & 0x0f;
    uint32_t qa1 = (qa & 0xf0) >> 4;
    qa >>= 8;
    qb |= (qa1 << (i * 4 + 16));
    qb |= (qa0 << (i * 4));
  }
  q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0,
                                               half2 (&dq)[4], int stride,
                                               const uint32_t zero) {
  const uint32_t c0 = 0x64006400;
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half2 y16 = __halves2half2(y16_, y16_);
  const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero);
  const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
  const half2 z1 = __half2half2(z1_.as_half);
  const half2 z16 = __half2half2(z16_);

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x000f000f) | c0);  // half2(q[ 0], q[ 1]) + 1024
  half2_uint32 q1((qa & 0x00f000f0) | c0);  // half2(q[ 2], q[ 3]) * 16 + 1024
  qa >>= 8;
  half2_uint32 q2((qa & 0x000f000f) | c0);  // half2(q[ 4], q[ 5]) + 1024
  half2_uint32 q3((qa & 0x00f000f0) | c0);  // half2(q[ 6], q[ 7]) * 16 + 1024

  dq[0] = __hadd2(q0.as_half2, z1);
  dq[1] = __hfma2(q1.as_half2, y16, z16);
  dq[2] = __hadd2(q2.as_half2, z1);
  dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale(
    const uint32_t zero, const half scale, half2 (&z1z16)[2],
    half2 (&y1y16)[2]) {
  half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero);
  half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

  half2 scale2 = __half2half2(scale);

  z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
  z1z16[1] = __hmul2(scale2, __half2half2(z16));

  const half y1 = __float2half_rn(1.0f);
  const half y16 = __float2half_rn(1.0f / 16.0f);

  y1y16[0] = __hmul2(scale2, __half2half2(y1));
  y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero,
                                                         half2 (&z1z16)[2],
                                                         half2 (&y1y16)[2]) {
  half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero);
  half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

  z1z16[0] = __half2half2(z1.as_half);
  z1z16[1] = __half2half2(z16);

  const half y1 = __float2half_rn(1.0f);
  const half y16 = __float2half_rn(1.0f / 16.0f);

  y1y16[0] = __half2half2(y1);
  y1y16[1] = __half2half2(y16);
}

__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0,
                                                    half2 (&dq)[4],
                                                    half2 (&z1z16)[2],
                                                    half2 (&y1y16)[2],
                                                    int stride, bool scaled) {
  const uint32_t c0 = 0x64006400;

  uint32_t qa = q_0;
  half2_uint32 q0((qa & 0x000f000f) |
                  c0);  // half2( q[0]      + 1024, q[1]      + 1024 )
  half2_uint32 q1((qa & 0x00f000f0) |
                  c0);  // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
  qa >>= 8;
  half2_uint32 q2((qa & 0x000f000f) |
                  c0);  // half2( q[4]      + 1024, q[5]      + 1024 )
  half2_uint32 q3((qa & 0x00f000f0) |
                  c0);  // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

  if (scaled) {
    dq[0] = __hfma2(q0.as_half2, y1y16[0],
                    z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s )
    dq[1] = __hfma2(q1.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s )
    dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
    dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
  } else {
    dq[0] = __hadd2(q0.as_half2, z1z16[0]);  // half2( q[0] - z, q[1] - z )
    dq[1] = __hfma2(q1.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[2] - z, q[3] - z )
    dq[2] = __hadd2(q2.as_half2, z1z16[0]);  // half2( q[4] - z, q[5] - z )
    dq[3] = __hfma2(q3.as_half2, y1y16[1],
                    z1z16[1]);  // half2( q[6] - z, q[7] - z )
  }
}

} // namespace gptq
} // namespace vllm

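The split between dequant_4bit_8_prep_zero_scale / dequant_4bit_8_prep_zero and dequant_4bit_8_gptq exists so the zero- and scale-dependent constants are computed once per quantization group and then fused into every unpacked pair with a single __hfma2. A rough call-pattern sketch follows (assumed usage, not code from this diff; dequant_group_sketch is a hypothetical helper and assumes the qdq_4.cuh functions above are visible in the translation unit):

#include <cstdint>
#include <cuda_fp16.h>

namespace vllm {
namespace gptq {

__device__ void dequant_group_sketch(const uint32_t* qwords, int n_words,
                                     uint32_t zero, half scale, half2* out) {
  // Prepare z*s and s (and their /16 variants) once for the whole group.
  half2 z1z16[2];
  half2 y1y16[2];
  dequant_4bit_8_prep_zero_scale(zero, scale, z1z16, y1y16);

  half2 dq[4];  // 8 dequantized values per packed 32-bit word
  for (int w = 0; w < n_words; w++) {
    dequant_4bit_8_gptq(qwords[w], dq, z1z16, y1y16, 1, true);
    for (int i = 0; i < 4; i++) out[w * 4 + i] = dq[i];
  }
}

}  // namespace gptq
}  // namespace vllm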
@@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm {
namespace gptq {

__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}

__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
                                               const uint32_t q_1,
                                               half2 (&dq)[4], int stride,
                                               const uint32_t zero) {
  half dqh[8];
  for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
  for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);

  for (int i = 0; i < 4; i++)
    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

} // namespace gptq
@@ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2
namespace vllm {
namespace gptq {

union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;
  __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
  __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16 {
  uint16_t as_uint16;
  half as_half;
  __device__ half_uint16(uint16_t val) : as_uint16(val) {}
  __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  int qs_i = qs + 1;
  half qs_h = __int2half_rn(qs_i * qs_i);
  qs_h = __hmul(qs_h, max_scale);
  return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero,
                                   const half scale) {
  return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
  return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift,
                                   const int mask) {
  return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
                                   const int shift, const int mask) {
  return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

} // namespace gptq
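The two-word exb() overload uses __funnelshift_rc so a bit field that straddles the boundary between two packed 32-bit words can still be pulled out with one instruction. A host-side model of that behaviour follows (illustrative only; exb_model is a hypothetical stand-in, valid for shifts in [0, 32] just like the clamped funnel shift):

#include <cassert>
#include <cstdint>

// Host-side model of exb(q1, q0, shift, mask): treat q1:q0 as one 64-bit
// word and extract a field that may cross the 32-bit boundary.
static int exb_model(uint32_t q1, uint32_t q0, int shift, int mask) {
  uint64_t cat = ((uint64_t)q1 << 32) | q0;
  return (int)((cat >> shift) & (uint32_t)mask);
}

int main() {
  // A 3-bit field placed across the boundary: bits 31..33 of q1:q0.
  uint32_t q0 = 0x80000000u;  // bit 31 set
  uint32_t q1 = 0x00000001u;  // bit 32 set
  assert(exb_model(q1, q0, 31, 0x07) == 0x03);
  return 0;
}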
File diff suppressed because it is too large

@@ -11,22 +11,23 @@

namespace gptq_marlin {

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

@@ -35,30 +36,35 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int n>
__device__ inline void cp_async_wait() {
@@ -67,4 +73,4 @@ __device__ inline void cp_async_wait() {

#endif

} // namespace gptq_marlin

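The cp.async wrappers above are normally used in a load / commit / wait pattern when staging tiles into shared memory. A rough device-side sketch of that pattern follows (assumed usage, not code from this diff; it presumes gptq_marlin.cuh is on the include path, compilation for sm_80 or newer, and blockDim.x == 256):

#include <cuda_runtime.h>
#include "gptq_marlin.cuh"

__global__ void stage_tile(const int4* __restrict__ gmem,
                           int4* __restrict__ out, int n) {
  __shared__ int4 smem[256];
  int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Issue one asynchronous 16-byte copy per thread, predicated on bounds.
  gptq_marlin::cp_async4_pred(&smem[threadIdx.x], &gmem[idx], idx < n);
  gptq_marlin::cp_async_fence();    // commit the outstanding copies as a group
  gptq_marlin::cp_async_wait<0>();  // wait until all committed groups landed
  __syncthreads();

  if (idx < n) out[idx] = smem[threadIdx.x];
}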
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh (new file, 77 lines)
@@ -0,0 +1,77 @@

#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace gptq_marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

} // namespace gptq_marlin

#endif

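The ScalarType traits class is how the marlin kernels stay generic over half and nv_bfloat16: a templated kernel pulls its paired type, broadcast, and float conversion from the traits instead of calling fp16 intrinsics directly. A minimal sketch of that usage follows (assumed usage, not code from this diff; broadcast_to_float is hypothetical, it presumes gptq_marlin_dtypes.cuh is on the include path, and the nv_bfloat16 instantiation needs sm_80+):

#include <cuda_fp16.h>
#include "gptq_marlin_dtypes.cuh"

template <typename scalar_t>
__global__ void broadcast_to_float(scalar_t x, float* out) {
  using Dtype = gptq_marlin::ScalarType<scalar_t>;
  typename Dtype::scalar_t2 pair = Dtype::num2num2(x);  // (x, x) in one register
  out[threadIdx.x] = Dtype::num2float(pair.x);          // convert back to fp32
}

// e.g. broadcast_to_float<half><<<1, 32>>>(__float2half(1.5f), out_dev);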
Some files were not shown because too many files have changed in this diff