mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-21 23:48:57 +08:00 
			
		
		
		
	Compare commits
	
		
			195 Commits
		
	
	
		
			v0.4.2
			...
			optimize-p
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| d5bf492f16 | |||
| f775a07e30 | |||
| 4f0d17c05c | |||
| 10c38e3e46 | |||
| cafb8e06c5 | |||
| cbb2f59cc8 | |||
| 0ab278ca31 | |||
| 7a64d24aad | |||
| 8c7bab79f5 | |||
| dfbe60dc62 | |||
| a66cf40b20 | |||
| f790ad3c50 | |||
| ed59a7ed23 | |||
| 1936d7bab0 | |||
| 996cf2de5c | |||
| 044793d8df | |||
| c2d6d2f960 | |||
| 8279078e21 | |||
| b9c0605a8e | |||
| 37464a0f74 | |||
| c354072828 | |||
| f081c3ce4b | |||
| 260d119e86 | |||
| a360ff80bb | |||
| 1197e02141 | |||
| 657579113f | |||
| e9899fb7a4 | |||
| a377f0bd5e | |||
| e9d3aa04f6 | |||
| a22dea54d3 | |||
| 533c217792 | |||
| 6d21fa1cad | |||
| b35be5403f | |||
| 45a1a69b98 | |||
| 87a658c812 | |||
| 429d89720e | |||
| a9bcc7afb2 | |||
| d79d9eaaff | |||
| f758505c73 | |||
| d910816c73 | |||
| 87d41c849d | |||
| e07aff9e52 | |||
| 5bf185a1c4 | |||
| 4fbcb0f27e | |||
| 7c3604fb68 | |||
| b1c255630d | |||
| eb6c50cdc2 | |||
| eecd864388 | |||
| ae495c74ea | |||
| 4238bc82f2 | |||
| 594392d27a | |||
| 18c1f16d86 | |||
| 5bd3c65072 | |||
| 616e600e0b | |||
| dfba529b40 | |||
| 5ae5ed1e60 | |||
| 290f4ada2b | |||
| dd8de11f0a | |||
| 9ba415588a | |||
| d4f3985907 | |||
| 890aa93d27 | |||
| fbdb7b3ee2 | |||
| 1102bef219 | |||
| f17a1a8f96 | |||
| d5a1697772 | |||
| 325c119961 | |||
| 8e192ff967 | |||
| e64fde4b01 | |||
| 919770957f | |||
| 6a50f4cafa | |||
| e3470f8753 | |||
| a1242324c9 | |||
| 5eda2ea02a | |||
| 2ba80bed27 | |||
| 6066253296 | |||
| ee3eea0a1b | |||
| a36de682d4 | |||
| eb6d3c264d | |||
| 97b030005c | |||
| a3a73ab069 | |||
| 8674f9880e | |||
| c74c913bfb | |||
| 5f6d10c14c | |||
| 9b9a10d6cb | |||
| 99eff67ba9 | |||
| 14772eeb8e | |||
| 757b62c495 | |||
| e941f88584 | |||
| f12c3b5b3d | |||
| d130b573a0 | |||
| 65ae8c2c8f | |||
| c3af44722c | |||
| 1937e29848 | |||
| f0eecee610 | |||
| 943e72ca56 | |||
| 546a97ef69 | |||
| da5a0b539d | |||
| 6287537a0c | |||
| b57e6c5949 | |||
| 27ce85476e | |||
| f68470e803 | |||
| 2e9a2227ec | |||
| c0724fc915 | |||
| 86b45ae065 | |||
| c5711ef985 | |||
| 48d5985a08 | |||
| 33e0823de5 | |||
| 26148120b3 | |||
| 0150a10630 | |||
| 8e7fb5d43a | |||
| 9a31a817a8 | |||
| 2060e93659 | |||
| 8435b207af | |||
| 10fa9eea21 | |||
| e08188081b | |||
| b5853f9963 | |||
| f09edd8a25 | |||
| 6979ade384 | |||
| 9216b9cc38 | |||
| 5e0391c040 | |||
| dbc0754ddf | |||
| 99caa49106 | |||
| 5c342570d7 | |||
| 973617ae02 | |||
| 30e754390c | |||
| 52f8107cf2 | |||
| fc0d9dfc3a | |||
| 361c461a12 | |||
| a5675d348b | |||
| e9cdd2b1e2 | |||
| 65bf2ac165 | |||
| 8a7cc254a0 | |||
| 29bc01bf3b | |||
| 676a99982f | |||
| dc72402b57 | |||
| ccb63a8245 | |||
| c579b750a0 | |||
| 4bfa7e7f75 | |||
| ac1fbf7fd2 | |||
| 33d3914b1e | |||
| 1356df53bd | |||
| ce532ff45c | |||
| 8bc68e198c | |||
| 0fca3cdcf2 | |||
| e7c46b9527 | |||
| 350f9e107f | |||
| 702bee461f | |||
| a7be4d0072 | |||
| a709e87a4f | |||
| 6eaccb7353 | |||
| e254497b66 | |||
| 4e12131089 | |||
| fcc2994be6 | |||
| 2e7796f2cf | |||
| 706588a77d | |||
| 6a0f617210 | |||
| dac6a3f6ed | |||
| 64b77dfd7e | |||
| 51d4094fda | |||
| e965d46184 | |||
| 208b71bcc1 | |||
| c833101740 | |||
| 379da6dcb5 | |||
| ebce310b74 | |||
| be0c5180ac | |||
| cea64430f6 | |||
| a3c124570a | |||
| ff5abcd746 | |||
| 0ee535b294 | |||
| 190bc838e1 | |||
| f12b20decc | |||
| 16bc0a098f | |||
| e288df0632 | |||
| 8b9241be3a | |||
| f942efb5a3 | |||
| 89579a201f | |||
| 230c4b38c1 | |||
| 20cfcdec99 | |||
| ad932a221d | |||
| 5510cf0e8a | |||
| 0f9a6e3d22 | |||
| f6a593093a | |||
| d7740ea4dc | |||
| cc466a3290 | |||
| 8344f7742b | |||
| 469f85c782 | |||
| 10760da800 | |||
| 478aed5827 | |||
| 63575bc2e1 | |||
| a98187cf72 | |||
| bd99d22629 | |||
| 19cb4716ee | |||
| e186d37cb1 | |||
| 323f27b904 | |||
| 0650e5935b | 
| @ -1,7 +1,7 @@ | ||||
| import os | ||||
| import zipfile | ||||
|  | ||||
| MAX_SIZE_MB = 100 | ||||
| MAX_SIZE_MB = 200 | ||||
|  | ||||
|  | ||||
| def print_top_10_largest_files(zip_file): | ||||
|  | ||||
| @ -1,10 +1,38 @@ | ||||
| # This script build the ROCm docker image and runs test inside it. | ||||
| # This script runs test inside the corresponding ROCm docker container. | ||||
| set -ex | ||||
|  | ||||
| # Print ROCm version | ||||
| echo "--- ROCm info" | ||||
| rocminfo | ||||
|  | ||||
| # cleanup older docker images | ||||
| cleanup_docker() { | ||||
|   # Get Docker's root directory | ||||
|   docker_root=$(docker info -f '{{.DockerRootDir}}') | ||||
|   if [ -z "$docker_root" ]; then | ||||
|     echo "Failed to determine Docker root directory." | ||||
|     exit 1 | ||||
|   fi | ||||
|   echo "Docker root directory: $docker_root" | ||||
|   # Check disk usage of the filesystem where Docker's root directory is located | ||||
|   disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') | ||||
|   # Define the threshold | ||||
|   threshold=70 | ||||
|   if [ "$disk_usage" -gt "$threshold" ]; then | ||||
|     echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." | ||||
|     # Remove dangling images (those that are not tagged and not used by any container) | ||||
|     docker image prune -f | ||||
|     # Remove unused volumes | ||||
|     docker volume prune -f | ||||
|     echo "Docker images and volumes cleanup completed." | ||||
|   else | ||||
|     echo "Disk usage is below $threshold%. No cleanup needed." | ||||
|   fi | ||||
| } | ||||
|  | ||||
| # Call the cleanup docker function | ||||
| cleanup_docker | ||||
|  | ||||
| echo "--- Resetting GPUs" | ||||
|  | ||||
| echo "reset" > /opt/amdgpu/etc/gpu_state | ||||
| @ -19,15 +47,16 @@ done | ||||
|  | ||||
| echo "--- Building container" | ||||
| sha=$(git rev-parse --short HEAD) | ||||
| container_name=rocm_${sha} | ||||
| image_name=rocm_${sha} | ||||
| container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo) | ||||
| docker build \ | ||||
|         -t ${container_name} \ | ||||
|         -t ${image_name} \ | ||||
|         -f Dockerfile.rocm \ | ||||
|         --progress plain \ | ||||
|         . | ||||
|  | ||||
| remove_docker_container() { | ||||
|    docker rm -f ${container_name} || docker image rm -f ${container_name} || true | ||||
|    docker rm -f ${container_name} || docker image rm -f ${image_name} || true | ||||
| } | ||||
| trap remove_docker_container EXIT | ||||
|  | ||||
| @ -39,6 +68,6 @@ docker run \ | ||||
|         --rm \ | ||||
|         -e HF_TOKEN \ | ||||
|         --name ${container_name} \ | ||||
|         ${container_name} \ | ||||
|         /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//") | ||||
|         ${image_name} \ | ||||
|         /bin/bash -c "${@}" | ||||
|  | ||||
|  | ||||
| @ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.." | ||||
| (which wget && which curl) || (apt-get update && apt-get install -y wget curl) | ||||
|  | ||||
| # run python-based benchmarks and upload the result to buildkite | ||||
| python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt | ||||
| python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt | ||||
| bench_latency_exit_code=$? | ||||
|  | ||||
| python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt | ||||
| python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt | ||||
| bench_throughput_exit_code=$? | ||||
|  | ||||
| # run server-based benchmarks and upload the result to buildkite | ||||
| @ -74,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then | ||||
|     exit $bench_serving_exit_code | ||||
| fi | ||||
|  | ||||
| /workspace/buildkite-agent artifact upload openai-*.json | ||||
| rm ShareGPT_V3_unfiltered_cleaned_split.json | ||||
| /workspace/buildkite-agent artifact upload "*.json" | ||||
|  | ||||
| @ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; } | ||||
| trap remove_docker_container EXIT | ||||
| remove_docker_container | ||||
|  | ||||
| # Run the image and launch offline inference | ||||
| docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py | ||||
| # Run the image | ||||
| docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test | ||||
|  | ||||
| # offline inference | ||||
| docker exec cpu-test bash -c "python3 examples/offline_inference.py" | ||||
|  | ||||
| # Run basic model test | ||||
| docker exec cpu-test bash -c "cd tests; | ||||
|   pip install pytest Pillow protobuf | ||||
|   bash ../.buildkite/download-images.sh | ||||
|   cd ../ | ||||
|   pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" | ||||
|  | ||||
| @ -5,13 +5,16 @@ | ||||
|  | ||||
| steps: | ||||
| - label: Regression Test | ||||
|   mirror_hardwares: [amd] | ||||
|   command: pytest -v -s test_regression.py | ||||
|   working_dir: "/vllm-workspace/tests" # optional | ||||
|  | ||||
| - label: AsyncEngine Test | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: pytest -v -s async_engine | ||||
|  | ||||
| - label: Basic Correctness Test | ||||
|   mirror_hardwares: [amd] | ||||
|   commands: | ||||
|   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py | ||||
|   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py | ||||
| @ -24,59 +27,68 @@ steps: | ||||
|   command: pytest -v -s core | ||||
|  | ||||
| - label: Distributed Comm Ops Test | ||||
|   command: pytest -v -s test_comm_ops.py | ||||
|   working_dir: "/vllm-workspace/tests/distributed" | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: pytest -v -s distributed/test_comm_ops.py | ||||
|   working_dir: "/vllm-workspace/tests" | ||||
|   num_gpus: 2 | ||||
|  | ||||
| - label: Distributed Tests | ||||
|   working_dir: "/vllm-workspace/tests/distributed" | ||||
|  | ||||
|   num_gpus: 2 # only support 1 or 2 for now. | ||||
|   mirror_hardwares: [amd] | ||||
|  | ||||
|   working_dir: "/vllm-workspace/tests" | ||||
|   num_gpus: 2 | ||||
|   commands: | ||||
|   - pytest -v -s test_pynccl_library.py | ||||
|   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py | ||||
|   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py | ||||
|   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py | ||||
|   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py | ||||
|   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py | ||||
|   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py | ||||
|   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py | ||||
|   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py | ||||
|   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py | ||||
|   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py | ||||
|   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py | ||||
|   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py | ||||
|   - pytest -v -s spec_decode/e2e/test_integration_dist.py  | ||||
|  | ||||
| - label: Distributed Tests (Multiple Groups) | ||||
|   working_dir: "/vllm-workspace/tests/distributed" | ||||
|   #mirror_hardwares: [amd] | ||||
|   working_dir: "/vllm-workspace/tests" | ||||
|   num_gpus: 4 | ||||
|   commands: | ||||
|   - pytest -v -s test_pynccl.py | ||||
|   - pytest -v -s distributed/test_pynccl.py | ||||
|  | ||||
| - label: Engine Test | ||||
|   mirror_hardwares: [amd] | ||||
|   command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py | ||||
|  | ||||
| - label: Entrypoints Test | ||||
|   mirror_hardwares: [amd] | ||||
|  | ||||
|   commands: | ||||
|   # these tests have to be separated, because each one will allocate all posible GPU memory | ||||
|   - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py | ||||
|   - pytest -v -s entrypoints/test_server_oot_registration.py | ||||
|   - pytest -v -s test_inputs.py | ||||
|   - pytest -v -s entrypoints -m llm | ||||
|   - pytest -v -s entrypoints -m openai | ||||
|  | ||||
| - label: Examples Test | ||||
|   working_dir: "/vllm-workspace/examples" | ||||
|   mirror_hardwares: [amd] | ||||
|   commands: | ||||
|     # install aws cli for llava_example.py | ||||
|     - pip install awscli | ||||
|     # install tensorizer for tensorize_vllm_model.py | ||||
|     - pip install awscli tensorizer | ||||
|     - python3 offline_inference.py | ||||
|     - python3 offline_inference_with_prefix.py | ||||
|     - python3 llm_engine_example.py | ||||
|     - python3 llava_example.py | ||||
|     - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors | ||||
|  | ||||
| - label: Kernels Test %N | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT | ||||
|   parallelism: 4 | ||||
|  | ||||
| - label: Models Test | ||||
|   mirror_hardwares: [amd] | ||||
|   #mirror_hardwares: [amd] | ||||
|   commands: | ||||
|     - bash ../.buildkite/download-images.sh | ||||
|     - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py | ||||
|     - pytest -v -s models --ignore=models/test_llava.py | ||||
|  | ||||
| - label: Llava Test | ||||
|   mirror_hardwares: [amd] | ||||
| @ -90,31 +102,53 @@ steps: | ||||
|     - pytest -v -s prefix_caching | ||||
|  | ||||
| - label: Samplers Test | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: pytest -v -s samplers | ||||
|  | ||||
| - label: LogitsProcessor Test | ||||
|   mirror_hardwares: [amd] | ||||
|   command: pytest -v -s test_logits_processor.py | ||||
|  | ||||
| - label: Utils Test | ||||
|   command: pytest -v -s test_utils.py | ||||
|  | ||||
| - label: Worker Test | ||||
|   mirror_hardwares: [amd] | ||||
|   command: pytest -v -s worker | ||||
|  | ||||
| - label: Speculative decoding tests | ||||
|   mirror_hardwares: [amd] | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: pytest -v -s spec_decode | ||||
|  | ||||
| - label: LoRA Test %N | ||||
|   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py | ||||
|   parallelism: 4 | ||||
|  | ||||
| - label: LoRA Long Context (Distributed) | ||||
|   #mirror_hardwares: [amd] | ||||
|   num_gpus: 4 | ||||
|   # This test runs llama 13B, so it is required to run on 4 GPUs. | ||||
|   commands: | ||||
|     # Temporarily run this way because we cannot clean up GPU mem usage | ||||
|     # for multi GPU tests. | ||||
|     # TODO(sang): Fix it. | ||||
|     - pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced | ||||
|     - pytest -v -s lora/test_long_context.py::test_batched_rope_kernel | ||||
|     - pytest -v -s lora/test_long_context.py::test_self_consistency | ||||
|     - pytest -v -s lora/test_long_context.py::test_quality | ||||
|     - pytest -v -s lora/test_long_context.py::test_max_len | ||||
|  | ||||
| - label: Tensorizer Test | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader | ||||
|  | ||||
| - label: Metrics Test | ||||
|   mirror_hardwares: [amd] | ||||
|   command: pytest -v -s metrics | ||||
|  | ||||
| - label: Quantization Test | ||||
|   #mirror_hardwares: [amd] | ||||
|   command: pytest -v -s quantization | ||||
|  | ||||
| - label: Benchmarks | ||||
|  | ||||
							
								
								
									
										59
									
								
								.buildkite/test-template-aws.j2
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								.buildkite/test-template-aws.j2
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,59 @@ | ||||
| {% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %} | ||||
| {% set default_working_dir = "/vllm-workspace/tests" %} | ||||
|  | ||||
| steps: | ||||
|   - label: ":docker: build image" | ||||
|     agents: | ||||
|       queue: cpu_queue | ||||
|     commands: | ||||
|       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" | ||||
|       - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." | ||||
|       - "docker push {{ docker_image }}" | ||||
|     env: | ||||
|       DOCKER_BUILDKIT: "1" | ||||
|     retry: | ||||
|       automatic: | ||||
|         - exit_status: -1  # Agent was lost | ||||
|           limit: 5 | ||||
|         - exit_status: -10  # Agent was lost | ||||
|           limit: 5 | ||||
|   - wait | ||||
|  | ||||
|   {% for step in steps %} | ||||
|   - label: "{{ step.label }}" | ||||
|     agents: | ||||
|       {% if step.no_gpu %} | ||||
|       queue: cpu_queue | ||||
|       {% elif step.num_gpus == 2 or step.num_gpus == 4 %} | ||||
|       queue: gpu_4_queue | ||||
|       {% else %} | ||||
|       queue: gpu_1_queue | ||||
|       {% endif %} | ||||
|     soft_fail: true | ||||
|     {% if step.parallelism %} | ||||
|     parallelism: {{ step.parallelism }} | ||||
|     {% endif %} | ||||
|     retry: | ||||
|       automatic: | ||||
|         - exit_status: -1  # Agent was lost | ||||
|           limit: 5 | ||||
|         - exit_status: -10  # Agent was lost | ||||
|           limit: 5 | ||||
|     plugins: | ||||
|       - docker#v5.2.0: | ||||
|           image: {{ docker_image }} | ||||
|           always-pull: true | ||||
|           propagate-environment: true | ||||
|           {% if not step.no_gpu %} | ||||
|           gpus: all | ||||
|           {% endif %} | ||||
|           command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe  }} && {{ step.command  or (step.commands | join(' && ')) | safe }}"] | ||||
|           environment: | ||||
|             - VLLM_USAGE_SOURCE=ci-test | ||||
|             - HF_TOKEN | ||||
|             {% if step.label == "Speculative decoding tests" %} | ||||
|             - VLLM_ATTENTION_BACKEND=XFORMERS | ||||
|             {% endif %} | ||||
|           volumes: | ||||
|             - /dev/shm:/dev/shm | ||||
|   {% endfor %} | ||||
| @ -3,9 +3,8 @@ | ||||
| {% set default_working_dir = "/vllm-workspace/tests" %} | ||||
|  | ||||
| steps: | ||||
|  | ||||
|   - label: ":docker: build image" | ||||
|     commands: | ||||
|     commands:  | ||||
|       - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." | ||||
|       - "docker push {{ docker_image }}" | ||||
|     env: | ||||
| @ -14,6 +13,8 @@ steps: | ||||
|       automatic: | ||||
|         - exit_status: -1  # Agent was lost | ||||
|           limit: 5 | ||||
|         - exit_status: -10  # Agent was lost | ||||
|           limit: 5 | ||||
|   - wait | ||||
|  | ||||
|   - group: "AMD Tests" | ||||
| @ -24,7 +25,7 @@ steps: | ||||
|       - label: "AMD: {{ step.label }}" | ||||
|         agents: | ||||
|           queue: amd | ||||
|         command: bash .buildkite/run-amd-test.sh "'cd {{ (step.working_dir or default_working_dir) | safe  }} && {{ step.command  or (step.commands | join(' && ')) | safe }}'" | ||||
|         command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe  }} ; {{ step.command  or (step.commands | join(" ; ")) | safe }}" | ||||
|         env: | ||||
|           DOCKER_BUILDKIT: "1" | ||||
|     {% endif %} | ||||
| @ -39,6 +40,8 @@ steps: | ||||
|  | ||||
|   - label: "Intel Test" | ||||
|     depends_on: ~ | ||||
|     agents: | ||||
|       queue: intel | ||||
|     command: bash .buildkite/run-cpu-test.sh | ||||
|  | ||||
|   {% for step in steps %} | ||||
| @ -53,6 +56,8 @@ steps: | ||||
|       automatic: | ||||
|         - exit_status: -1  # Agent was lost | ||||
|           limit: 5 | ||||
|         - exit_status: -10  # Agent was lost | ||||
|           limit: 5 | ||||
|     plugins: | ||||
|       - kubernetes: | ||||
|           podSpec: | ||||
|  | ||||
							
								
								
									
										26
									
								
								.clang-format
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								.clang-format
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | ||||
| BasedOnStyle: Google | ||||
| UseTab: Never | ||||
| IndentWidth: 2 | ||||
| ColumnLimit: 80 | ||||
|  | ||||
| # Force pointers to the type for C++. | ||||
| DerivePointerAlignment: false | ||||
| PointerAlignment: Left | ||||
|  | ||||
| # Reordering #include statements can (and currently will) introduce errors | ||||
| SortIncludes: false | ||||
|  | ||||
| # Style choices | ||||
| AlignConsecutiveAssignments: false | ||||
| AlignConsecutiveDeclarations: false | ||||
| IndentPPDirectives: BeforeHash | ||||
|  | ||||
| IncludeCategories: | ||||
|   - Regex:           '^<' | ||||
|     Priority:        4 | ||||
|   - Regex:           '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' | ||||
|     Priority:        3 | ||||
|   - Regex:           '^"(qoda|\.\.)/' | ||||
|     Priority:        2 | ||||
|   - Regex:           '.*' | ||||
|     Priority:        1 | ||||
							
								
								
									
										2
									
								
								.github/ISSUE_TEMPLATE/400-bug report.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ISSUE_TEMPLATE/400-bug report.yml
									
									
									
									
										vendored
									
									
								
							| @ -59,6 +59,8 @@ body: | ||||
|  | ||||
|       Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. | ||||
|  | ||||
|       Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. | ||||
|  | ||||
|       If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. | ||||
|     placeholder: | | ||||
|       A clear and concise description of what the bug is. | ||||
|  | ||||
							
								
								
									
										42
									
								
								.github/workflows/clang-format.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								.github/workflows/clang-format.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,42 @@ | ||||
| name: clang-format | ||||
|  | ||||
| on: | ||||
|   # Trigger the workflow on push or pull request, | ||||
|   # but only for the main branch | ||||
|   push: | ||||
|     branches: | ||||
|       - main | ||||
|   pull_request: | ||||
|     branches: | ||||
|       - main | ||||
|  | ||||
| jobs: | ||||
|   clang-format: | ||||
|     runs-on: ubuntu-latest | ||||
|     strategy: | ||||
|       matrix: | ||||
|         python-version: ["3.11"] | ||||
|     steps: | ||||
|     - uses: actions/checkout@v2 | ||||
|     - name: Set up Python ${{ matrix.python-version }} | ||||
|       uses: actions/setup-python@v2 | ||||
|       with: | ||||
|         python-version: ${{ matrix.python-version }} | ||||
|     - name: Install dependencies | ||||
|       run: | | ||||
|         python -m pip install --upgrade pip | ||||
|         pip install clang-format==18.1.5 | ||||
|     - name: Running clang-format | ||||
|       run: | | ||||
|         EXCLUDES=( | ||||
|             'csrc/moe/topk_softmax_kernels.cu' | ||||
|             'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' | ||||
|             'csrc/punica/bgmv/bgmv_config.h' | ||||
|             'csrc/punica/bgmv/bgmv_impl.cuh' | ||||
|             'csrc/punica/bgmv/vec_dtypes.cuh' | ||||
|             'csrc/punica/punica_ops.cu' | ||||
|             'csrc/punica/type_convert.h' | ||||
|         ) | ||||
|         find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ | ||||
|             | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ | ||||
|             | xargs clang-format --dry-run --Werror | ||||
							
								
								
									
										1
									
								
								.github/workflows/mypy.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/mypy.yaml
									
									
									
									
										vendored
									
									
								
							| @ -37,6 +37,7 @@ jobs: | ||||
|         mypy vllm/distributed --config-file pyproject.toml | ||||
|         mypy vllm/entrypoints --config-file pyproject.toml | ||||
|         mypy vllm/executor --config-file pyproject.toml | ||||
|         mypy vllm/multimodal --config-file pyproject.toml | ||||
|         mypy vllm/usage --config-file pyproject.toml | ||||
|         mypy vllm/*.py --config-file pyproject.toml | ||||
|         mypy vllm/transformers_utils --config-file pyproject.toml | ||||
|  | ||||
							
								
								
									
										3
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -58,6 +58,9 @@ jobs: | ||||
|  | ||||
|       - name: Setup ccache | ||||
|         uses: hendrikmuhs/ccache-action@v1.2 | ||||
|         with: | ||||
|           create-symlink: true | ||||
|           key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} | ||||
|  | ||||
|       - name: Set up Linux Env | ||||
|         if: ${{ runner.os == 'Linux' }} | ||||
|  | ||||
| @ -167,19 +167,47 @@ set(VLLM_EXT_SRC | ||||
|   "csrc/layernorm_kernels.cu" | ||||
|   "csrc/quantization/squeezellm/quant_cuda_kernel.cu" | ||||
|   "csrc/quantization/gptq/q_gemm.cu" | ||||
|   "csrc/quantization/fp8/fp8_cuda_kernels.cu" | ||||
|   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" | ||||
|   "csrc/quantization/fp8/common.cu" | ||||
|   "csrc/cuda_utils_kernels.cu" | ||||
|   "csrc/moe_align_block_size_kernels.cu" | ||||
|   "csrc/pybind.cpp") | ||||
|  | ||||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|   include(FetchContent) | ||||
|   SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) | ||||
|   FetchContent_Declare( | ||||
|         cutlass | ||||
|         GIT_REPOSITORY https://github.com/nvidia/cutlass.git | ||||
|         # CUTLASS 3.5.0 | ||||
|         GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc | ||||
|   ) | ||||
|   FetchContent_MakeAvailable(cutlass) | ||||
|  | ||||
|   list(APPEND VLLM_EXT_SRC | ||||
|     "csrc/quantization/aqlm/gemm_kernels.cu" | ||||
|     "csrc/quantization/awq/gemm_kernels.cu" | ||||
|     "csrc/quantization/marlin/marlin_cuda_kernel.cu" | ||||
|     "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" | ||||
|     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" | ||||
|     "csrc/quantization/gptq_marlin/gptq_marlin.cu" | ||||
|     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" | ||||
|     "csrc/custom_all_reduce.cu") | ||||
|     "csrc/custom_all_reduce.cu" | ||||
|     "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu" | ||||
|     "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu" | ||||
|     "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu") | ||||
|  | ||||
|   # | ||||
|   # The CUTLASS kernels for Hopper require sm90a to be enabled. | ||||
|   # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. | ||||
|   # That adds an extra 17MB to compiled binary, so instead we selectively enable it. | ||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) | ||||
|     set_source_files_properties( | ||||
|           "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" | ||||
|           PROPERTIES | ||||
|           COMPILE_FLAGS | ||||
|           "-gencode arch=compute_90a,code=sm_90a") | ||||
|   endif() | ||||
|  | ||||
| endif() | ||||
|  | ||||
| define_gpu_extension_target( | ||||
| @ -189,6 +217,7 @@ define_gpu_extension_target( | ||||
|   SOURCES ${VLLM_EXT_SRC} | ||||
|   COMPILE_FLAGS ${VLLM_GPU_FLAGS} | ||||
|   ARCHITECTURES ${VLLM_GPU_ARCHES} | ||||
|   INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} | ||||
|   WITH_SOABI) | ||||
|  | ||||
| # | ||||
| @ -219,7 +248,8 @@ set(VLLM_PUNICA_EXT_SRC | ||||
|   "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" | ||||
|   "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" | ||||
|   "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" | ||||
|   "csrc/punica/punica_ops.cc") | ||||
|   "csrc/punica/punica_ops.cu" | ||||
|   "csrc/punica/punica_pybind.cpp") | ||||
|  | ||||
| # | ||||
| # Copy GPU compilation flags+update for punica | ||||
| @ -243,6 +273,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA") | ||||
|     endif() | ||||
|   endforeach() | ||||
|   message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") | ||||
| elseif(${VLLM_GPU_LANG} STREQUAL "HIP") | ||||
|   set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) | ||||
|   message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") | ||||
| endif() | ||||
|  | ||||
| if (VLLM_PUNICA_GPU_ARCHES) | ||||
| @ -277,9 +310,7 @@ add_custom_target(default) | ||||
| if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") | ||||
|   message(STATUS "Enabling C extension.") | ||||
|   add_dependencies(default _C) | ||||
| endif() | ||||
|  | ||||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|   message(STATUS "Enabling moe extension.") | ||||
|   add_dependencies(default _moe_C) | ||||
|  | ||||
|  | ||||
							
								
								
									
										27
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								Dockerfile
									
									
									
									
									
								
							| @ -79,31 +79,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ | ||||
| COPY .buildkite/check-wheel-size.py check-wheel-size.py | ||||
| RUN python3 check-wheel-size.py dist | ||||
|  | ||||
| # the `vllm_nccl` package must be installed from source distribution | ||||
| # pip is too smart to store a wheel in the cache, and other CI jobs | ||||
| # will directly use the wheel from the cache, which is not what we want. | ||||
| # we need to remove it manually | ||||
| RUN --mount=type=cache,target=/root/.cache/pip \ | ||||
|     pip cache remove vllm_nccl* | ||||
| #################### EXTENSION Build IMAGE #################### | ||||
|  | ||||
| #################### FLASH_ATTENTION Build IMAGE #################### | ||||
| FROM dev as flash-attn-builder | ||||
| # max jobs used for build | ||||
| ARG max_jobs=2 | ||||
| ENV MAX_JOBS=${max_jobs} | ||||
| # flash attention version | ||||
| ARG flash_attn_version=v2.5.8 | ||||
| ENV FLASH_ATTN_VERSION=${flash_attn_version} | ||||
|  | ||||
| WORKDIR /usr/src/flash-attention-v2 | ||||
|  | ||||
| # Download the wheel or build it if a pre-compiled release doesn't exist | ||||
| RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ | ||||
|     --no-build-isolation --no-deps --no-cache-dir | ||||
|  | ||||
| #################### FLASH_ATTENTION Build IMAGE #################### | ||||
|  | ||||
| #################### vLLM installation IMAGE #################### | ||||
| # image with vLLM installed | ||||
| FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base | ||||
| @ -122,10 +99,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ | ||||
| RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ | ||||
|     --mount=type=cache,target=/root/.cache/pip \ | ||||
|     pip install dist/*.whl --verbose | ||||
|  | ||||
| RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ | ||||
|     --mount=type=cache,target=/root/.cache/pip \ | ||||
|     pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir | ||||
| #################### vLLM installation IMAGE #################### | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. | ||||
|  | ||||
| FROM ubuntu:22.04 | ||||
| FROM ubuntu:22.04 AS cpu-test-1 | ||||
|  | ||||
| RUN apt-get update  -y \ | ||||
|     && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ | ||||
| @ -9,6 +9,8 @@ RUN apt-get update  -y \ | ||||
| RUN pip install --upgrade pip \ | ||||
|     && pip install wheel packaging ninja setuptools>=49.4.0 numpy | ||||
|  | ||||
| FROM cpu-test-1 AS build | ||||
|  | ||||
| COPY ./ /workspace/vllm | ||||
|  | ||||
| WORKDIR /workspace/vllm | ||||
| @ -17,4 +19,8 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py | ||||
|  | ||||
| RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install | ||||
|  | ||||
| WORKDIR /workspace/ | ||||
|  | ||||
| RUN ln -s /workspace/vllm/tests  && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks | ||||
|  | ||||
| CMD ["/bin/bash"] | ||||
|  | ||||
| @ -92,16 +92,24 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ | ||||
| WORKDIR /vllm-workspace | ||||
| COPY . . | ||||
|  | ||||
| #RUN python3 -m pip install pynvml # to be removed eventually | ||||
| RUN python3 -m pip install --upgrade pip numba | ||||
|  | ||||
| # make sure punica kernels are built (for LoRA) | ||||
| ENV VLLM_INSTALL_PUNICA_KERNELS=1 | ||||
| # Workaround for ray >= 2.10.0 | ||||
| ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 | ||||
|  | ||||
| ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so | ||||
|  | ||||
| RUN --mount=type=cache,target=/root/.cache/pip \ | ||||
|     pip install -U -r requirements-rocm.txt \ | ||||
|     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ | ||||
|     && python3 setup.py install \ | ||||
|     && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ | ||||
|     && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ | ||||
|     && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ | ||||
|     && cd .. | ||||
|  | ||||
| RUN python3 -m pip install --upgrade pip | ||||
| RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 | ||||
|  | ||||
| CMD ["/bin/bash"] | ||||
|  | ||||
							
								
								
									
										82
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										82
									
								
								README.md
									
									
									
									
									
								
							| @ -14,6 +14,17 @@ Easy, fast, and cheap LLM serving for everyone | ||||
|  | ||||
| </p> | ||||
|  | ||||
| --- | ||||
|  | ||||
| **The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)** | ||||
|  | ||||
| We are thrilled to announce our fourth vLLM Meetup! | ||||
| The vLLM team will share recent updates and roadmap. | ||||
| We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM. | ||||
| Please register [here](https://lu.ma/agivllm) and join us! | ||||
|  | ||||
| --- | ||||
|  | ||||
| *Latest News* 🔥 | ||||
| - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). | ||||
| - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). | ||||
| @ -51,41 +62,14 @@ vLLM is flexible and easy to use with: | ||||
| - (Experimental) Prefix caching support | ||||
| - (Experimental) Multi-lora support | ||||
|  | ||||
| vLLM seamlessly supports many Hugging Face models, including the following architectures: | ||||
| vLLM seamlessly supports most popular open-source models on HuggingFace, including: | ||||
| - Transformer-like LLMs (e.g., Llama) | ||||
| - Mixture-of-Expert LLMs (e.g., Mixtral) | ||||
| - Multi-modal LLMs (e.g., LLaVA) | ||||
|  | ||||
| - Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) | ||||
| - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) | ||||
| - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) | ||||
| - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) | ||||
| - Command-R (`CohereForAI/c4ai-command-r-v01`, etc.) | ||||
| - DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) | ||||
| - DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) | ||||
| - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) | ||||
| - Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) | ||||
| - GPT-2 (`gpt2`, `gpt2-xl`, etc.) | ||||
| - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) | ||||
| - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) | ||||
| - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) | ||||
| - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) | ||||
| - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) | ||||
| - Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) | ||||
| - LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) | ||||
| - MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) | ||||
| - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) | ||||
| - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) | ||||
| - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) | ||||
| - OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.) | ||||
| - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) | ||||
| - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) | ||||
| - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) | ||||
| - Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) | ||||
| - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) | ||||
| - Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) | ||||
| - Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) | ||||
| - StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) | ||||
| - Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) | ||||
| - Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) | ||||
| - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) | ||||
| Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). | ||||
|  | ||||
| ## Getting Started | ||||
|  | ||||
| Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): | ||||
|  | ||||
| @ -93,9 +77,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get | ||||
| pip install vllm | ||||
| ``` | ||||
|  | ||||
| ## Getting Started | ||||
|  | ||||
| Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. | ||||
| Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. | ||||
| - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) | ||||
| - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) | ||||
| - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) | ||||
| @ -105,6 +87,32 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started | ||||
| We welcome and value any contributions and collaborations. | ||||
| Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. | ||||
|  | ||||
| ## Sponsors | ||||
|  | ||||
| vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! | ||||
|  | ||||
| <!-- Note: Please sort them in alphabetical order. --> | ||||
| <!-- Note: Please keep these consistent with docs/source/community/sponsors.md --> | ||||
|  | ||||
| - a16z | ||||
| - AMD | ||||
| - Anyscale | ||||
| - AWS | ||||
| - Crusoe Cloud | ||||
| - Databricks | ||||
| - DeepInfra | ||||
| - Dropbox | ||||
| - Lambda Lab | ||||
| - NVIDIA | ||||
| - Replicate | ||||
| - Roblox | ||||
| - RunPod | ||||
| - Trainy | ||||
| - UC Berkeley | ||||
| - UC San Diego | ||||
|  | ||||
| We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. | ||||
|  | ||||
| ## Citation | ||||
|  | ||||
| If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): | ||||
|  | ||||
| @ -89,6 +89,9 @@ async def async_request_tgi( | ||||
|                     output.latency = most_recent_timestamp - st | ||||
|                     output.success = True | ||||
|                     output.generated_text = data["generated_text"] | ||||
|                 else: | ||||
|                     output.error = response.reason or "" | ||||
|                     output.success = False | ||||
|         except Exception: | ||||
|             output.success = False | ||||
|             exc_info = sys.exc_info() | ||||
| @ -276,6 +279,9 @@ async def async_request_openai_completions( | ||||
|                     output.generated_text = generated_text | ||||
|                     output.success = True | ||||
|                     output.latency = latency | ||||
|                 else: | ||||
|                     output.error = response.reason or "" | ||||
|                     output.success = False | ||||
|         except Exception: | ||||
|             output.success = False | ||||
|             exc_info = sys.exc_info() | ||||
|  | ||||
| @ -1,14 +1,16 @@ | ||||
| """Benchmark the latency of processing a single batch of requests.""" | ||||
| import argparse | ||||
| import json | ||||
| import time | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
| from typing import List, Optional | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from vllm import LLM, SamplingParams | ||||
| from vllm.inputs import PromptStrictInputs | ||||
| from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS | ||||
|  | ||||
|  | ||||
| @ -18,6 +20,8 @@ def main(args: argparse.Namespace): | ||||
|     # NOTE(woosuk): If the request cannot be processed in a single batch, | ||||
|     # the engine will automatically process the request in multiple batches. | ||||
|     llm = LLM(model=args.model, | ||||
|               speculative_model=args.speculative_model, | ||||
|               num_speculative_tokens=args.num_speculative_tokens, | ||||
|               tokenizer=args.tokenizer, | ||||
|               quantization=args.quantization, | ||||
|               tensor_parallel_size=args.tensor_parallel_size, | ||||
| @ -28,9 +32,11 @@ def main(args: argparse.Namespace): | ||||
|               quantization_param_path=args.quantization_param_path, | ||||
|               device=args.device, | ||||
|               ray_workers_use_nsight=args.ray_workers_use_nsight, | ||||
|               use_v2_block_manager=args.use_v2_block_manager, | ||||
|               enable_chunked_prefill=args.enable_chunked_prefill, | ||||
|               download_dir=args.download_dir, | ||||
|               block_size=args.block_size) | ||||
|               block_size=args.block_size, | ||||
|               gpu_memory_utilization=args.gpu_memory_utilization) | ||||
|  | ||||
|     sampling_params = SamplingParams( | ||||
|         n=args.n, | ||||
| @ -44,7 +50,9 @@ def main(args: argparse.Namespace): | ||||
|     dummy_prompt_token_ids = np.random.randint(10000, | ||||
|                                                size=(args.batch_size, | ||||
|                                                      args.input_len)) | ||||
|     dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() | ||||
|     dummy_inputs: List[PromptStrictInputs] = [{ | ||||
|         "prompt_token_ids": batch | ||||
|     } for batch in dummy_prompt_token_ids.tolist()] | ||||
|  | ||||
|     def run_to_completion(profile_dir: Optional[str] = None): | ||||
|         if profile_dir: | ||||
| @ -55,13 +63,13 @@ def main(args: argparse.Namespace): | ||||
|                     ], | ||||
|                     on_trace_ready=torch.profiler.tensorboard_trace_handler( | ||||
|                         str(profile_dir))) as p: | ||||
|                 llm.generate(prompt_token_ids=dummy_prompt_token_ids, | ||||
|                 llm.generate(dummy_inputs, | ||||
|                              sampling_params=sampling_params, | ||||
|                              use_tqdm=False) | ||||
|             print(p.key_averages()) | ||||
|         else: | ||||
|             start_time = time.perf_counter() | ||||
|             llm.generate(prompt_token_ids=dummy_prompt_token_ids, | ||||
|             llm.generate(dummy_inputs, | ||||
|                          sampling_params=sampling_params, | ||||
|                          use_tqdm=False) | ||||
|             end_time = time.perf_counter() | ||||
| @ -93,12 +101,24 @@ def main(args: argparse.Namespace): | ||||
|     for percentage, percentile in zip(percentages, percentiles): | ||||
|         print(f'{percentage}% percentile latency: {percentile} seconds') | ||||
|  | ||||
|     # Output JSON results if specified | ||||
|     if args.output_json: | ||||
|         results = { | ||||
|             "avg_latency": np.mean(latencies), | ||||
|             "latencies": latencies.tolist(), | ||||
|             "percentiles": dict(zip(percentages, percentiles.tolist())), | ||||
|         } | ||||
|         with open(args.output_json, "w") as f: | ||||
|             json.dump(results, f, indent=4) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description='Benchmark the latency of processing a single batch of ' | ||||
|         'requests till completion.') | ||||
|     parser.add_argument('--model', type=str, default='facebook/opt-125m') | ||||
|     parser.add_argument('--speculative-model', type=str, default=None) | ||||
|     parser.add_argument('--num-speculative-tokens', type=int, default=None) | ||||
|     parser.add_argument('--tokenizer', type=str, default=None) | ||||
|     parser.add_argument('--quantization', | ||||
|                         '-q', | ||||
| @ -137,15 +157,13 @@ if __name__ == '__main__': | ||||
|                         action='store_true', | ||||
|                         help='enforce eager mode and disable CUDA graph') | ||||
|     parser.add_argument( | ||||
|         "--kv-cache-dtype", | ||||
|         '--kv-cache-dtype', | ||||
|         type=str, | ||||
|         choices=['auto', 'fp8'], | ||||
|         default='auto', | ||||
|         help= | ||||
|         'Data type for kv cache storage. If "auto", will use model data type. ' | ||||
|         'FP8_E5M2 (without scaling) is only supported on cuda version greater ' | ||||
|         'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' | ||||
|         'common inference criteria.') | ||||
|         choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], | ||||
|         default="auto", | ||||
|         help='Data type for kv cache storage. If "auto", will use model ' | ||||
|         'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' | ||||
|         'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') | ||||
|     parser.add_argument( | ||||
|         '--quantization-param-path', | ||||
|         type=str, | ||||
| @ -181,6 +199,7 @@ if __name__ == '__main__': | ||||
|         action='store_true', | ||||
|         help='If True, the prefill requests can be chunked based on the ' | ||||
|         'max_num_batched_tokens') | ||||
|     parser.add_argument('--use-v2-block-manager', action='store_true') | ||||
|     parser.add_argument( | ||||
|         "--ray-workers-use-nsight", | ||||
|         action='store_true', | ||||
| @ -191,5 +210,16 @@ if __name__ == '__main__': | ||||
|                         default=None, | ||||
|                         help='directory to download and load the weights, ' | ||||
|                         'default to the default cache dir of huggingface') | ||||
|     parser.add_argument( | ||||
|         '--output-json', | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help='Path to save the latency results in JSON format.') | ||||
|     parser.add_argument('--gpu-memory-utilization', | ||||
|                         type=float, | ||||
|                         default=0.9, | ||||
|                         help='the fraction of GPU memory to be used for ' | ||||
|                         'the model executor, which can range from 0 to 1.' | ||||
|                         'If unspecified, will use the default value of 0.9.') | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
|  | ||||
| @ -17,6 +17,10 @@ On the client side, run: | ||||
|         --dataset-path <path to dataset> \ | ||||
|         --request-rate <request_rate> \ # By default <request_rate> is inf | ||||
|         --num-prompts <num_prompts> # By default <num_prompts> is 1000 | ||||
|          | ||||
|     when using tgi backend, add | ||||
|         --endpoint /generate_stream | ||||
|     to the end of the command above. | ||||
| """ | ||||
| import argparse | ||||
| import asyncio | ||||
| @ -211,6 +215,11 @@ def calculate_metrics( | ||||
|         else: | ||||
|             actual_output_lens.append(0) | ||||
|  | ||||
|     if completed == 0: | ||||
|         warnings.warn( | ||||
|             "All requests failed. This is likely due to a misconfiguration " | ||||
|             "on the benchmark arguments.", | ||||
|             stacklevel=2) | ||||
|     metrics = BenchmarkMetrics( | ||||
|         completed=completed, | ||||
|         total_input=total_input, | ||||
| @ -222,9 +231,9 @@ def calculate_metrics( | ||||
|         1000,  # ttfts is empty if streaming is not supported by backend | ||||
|         median_ttft_ms=np.median(ttfts or 0) * 1000, | ||||
|         p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, | ||||
|         mean_tpot_ms=np.mean(tpots) * 1000, | ||||
|         median_tpot_ms=np.median(tpots) * 1000, | ||||
|         p99_tpot_ms=np.percentile(tpots, 99) * 1000, | ||||
|         mean_tpot_ms=np.mean(tpots or 0) * 1000, | ||||
|         median_tpot_ms=np.median(tpots or 0) * 1000, | ||||
|         p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, | ||||
|     ) | ||||
|  | ||||
|     return metrics, actual_output_lens | ||||
| @ -246,6 +255,24 @@ async def benchmark( | ||||
|     else: | ||||
|         raise ValueError(f"Unknown backend: {backend}") | ||||
|  | ||||
|     print("Starting initial single prompt test run...") | ||||
|     test_prompt, test_prompt_len, test_output_len = input_requests[0] | ||||
|     test_input = RequestFuncInput( | ||||
|         model=model_id, | ||||
|         prompt=test_prompt, | ||||
|         api_url=api_url, | ||||
|         prompt_len=test_prompt_len, | ||||
|         output_len=test_output_len, | ||||
|         best_of=best_of, | ||||
|         use_beam_search=use_beam_search, | ||||
|     ) | ||||
|     test_output = await request_func(request_func_input=test_input) | ||||
|     if not test_output.success: | ||||
|         raise ValueError( | ||||
|             "Initial test run failed - Please make sure benchmark arguments " | ||||
|             f"are correctly specified. Error: {test_output.error}") | ||||
|     else: | ||||
|         print("Initial test run completed. Starting main benchmark run...") | ||||
|     print(f"Traffic request rate: {request_rate}") | ||||
|  | ||||
|     pbar = None if disable_tqdm else tqdm(total=len(input_requests)) | ||||
|  | ||||
| @ -242,6 +242,18 @@ def main(args: argparse.Namespace): | ||||
|     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " | ||||
|           f"{total_num_tokens / elapsed_time:.2f} tokens/s") | ||||
|  | ||||
|     # Output JSON results if specified | ||||
|     if args.output_json: | ||||
|         results = { | ||||
|             "elapsed_time": elapsed_time, | ||||
|             "num_requests": len(requests), | ||||
|             "total_num_tokens": total_num_tokens, | ||||
|             "requests_per_second": len(requests) / elapsed_time, | ||||
|             "tokens_per_second": total_num_tokens / elapsed_time, | ||||
|         } | ||||
|         with open(args.output_json, "w") as f: | ||||
|             json.dump(results, f, indent=4) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="Benchmark the throughput.") | ||||
| @ -311,15 +323,13 @@ if __name__ == "__main__": | ||||
|                         action="store_true", | ||||
|                         help="enforce eager execution") | ||||
|     parser.add_argument( | ||||
|         "--kv-cache-dtype", | ||||
|         '--kv-cache-dtype', | ||||
|         type=str, | ||||
|         choices=["auto", "fp8"], | ||||
|         choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], | ||||
|         default="auto", | ||||
|         help= | ||||
|         'Data type for kv cache storage. If "auto", will use model data type. ' | ||||
|         'FP8_E5M2 (without scaling) is only supported on cuda version greater ' | ||||
|         'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' | ||||
|         'common inference criteria.') | ||||
|         help='Data type for kv cache storage. If "auto", will use model ' | ||||
|         'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' | ||||
|         'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') | ||||
|     parser.add_argument( | ||||
|         '--quantization-param-path', | ||||
|         type=str, | ||||
| @ -353,6 +363,11 @@ if __name__ == "__main__": | ||||
|                         default=None, | ||||
|                         help='directory to download and load the weights, ' | ||||
|                         'default to the default cache dir of huggingface') | ||||
|     parser.add_argument( | ||||
|         '--output-json', | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help='Path to save the throughput results in JSON format.') | ||||
|     args = parser.parse_args() | ||||
|     if args.tokenizer is None: | ||||
|         args.tokenizer = args.model | ||||
|  | ||||
							
								
								
									
										352
									
								
								benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										352
									
								
								benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,352 @@ | ||||
| import argparse | ||||
| import copy | ||||
| import itertools | ||||
| import pickle as pkl | ||||
| import time | ||||
| from typing import Callable, Iterable, List, Tuple | ||||
|  | ||||
| import torch | ||||
| import torch.utils.benchmark as TBenchmark | ||||
| from torch.utils.benchmark import Measurement as TMeasurement | ||||
| from weight_shapes import WEIGHT_SHAPES | ||||
|  | ||||
| from vllm import _custom_ops as ops | ||||
|  | ||||
| DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] | ||||
| DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] | ||||
| DEFAULT_TP_SIZES = [1] | ||||
|  | ||||
| # helpers | ||||
|  | ||||
|  | ||||
| def to_fp8(tensor: torch.tensor) -> torch.tensor: | ||||
|     finfo = torch.finfo(torch.float8_e4m3fn) | ||||
|     return torch.round(tensor.clamp( | ||||
|         min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) | ||||
|  | ||||
|  | ||||
| def to_int8(tensor: torch.tensor) -> torch.tensor: | ||||
|     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) | ||||
|  | ||||
|  | ||||
| def make_rand_tensors(dtype: torch.dtype, m: int, n: int, | ||||
|                       k: int) -> Tuple[torch.tensor, torch.tensor]: | ||||
|  | ||||
|     a = torch.randn((m, k), device='cuda') * 5 | ||||
|     b = torch.randn((n, k), device='cuda').t() * 5 | ||||
|  | ||||
|     if dtype == torch.int8: | ||||
|         return to_int8(a), to_int8(b) | ||||
|     if dtype == torch.float8_e4m3fn: | ||||
|         return to_fp8(a), to_fp8(b) | ||||
|  | ||||
|     raise ValueError("unsupported dtype") | ||||
|  | ||||
|  | ||||
| # impl | ||||
|  | ||||
|  | ||||
| def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, | ||||
|                     scale_b: torch.tensor, | ||||
|                     out_dtype: torch.dtype) -> torch.tensor: | ||||
|     return torch.mm(a, b) | ||||
|  | ||||
|  | ||||
| def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, | ||||
|                      scale_b: torch.tensor, | ||||
|                      out_dtype: torch.dtype) -> torch.tensor: | ||||
|     return torch._scaled_mm(a, | ||||
|                             b, | ||||
|                             scale_a=scale_a, | ||||
|                             scale_b=scale_b, | ||||
|                             out_dtype=out_dtype) | ||||
|  | ||||
|  | ||||
| def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, | ||||
|                                 scale_a: torch.tensor, scale_b: torch.tensor, | ||||
|                                 out_dtype: torch.dtype) -> torch.tensor: | ||||
|     return torch._scaled_mm(a, | ||||
|                             b, | ||||
|                             scale_a=scale_a, | ||||
|                             scale_b=scale_b, | ||||
|                             out_dtype=out_dtype, | ||||
|                             use_fast_accum=True) | ||||
|  | ||||
|  | ||||
| def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, | ||||
|                  scale_b: torch.tensor, | ||||
|                  out_dtype: torch.dtype) -> torch.tensor: | ||||
|     return ops.cutlass_scaled_mm_dq(a, | ||||
|                                     b, | ||||
|                                     scale_a, | ||||
|                                     scale_b, | ||||
|                                     out_dtype=out_dtype) | ||||
|  | ||||
|  | ||||
| # bench | ||||
| def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, | ||||
|              scale_b: torch.tensor, out_dtype: torch.dtype, label: str, | ||||
|              sub_label: str, fn: Callable, description: str) -> TMeasurement: | ||||
|  | ||||
|     min_run_time = 1 | ||||
|  | ||||
|     globals = { | ||||
|         "a": a, | ||||
|         "b": b, | ||||
|         "scale_a": scale_a, | ||||
|         "scale_b": scale_b, | ||||
|         "out_dtype": out_dtype, | ||||
|         "fn": fn, | ||||
|     } | ||||
|     return TBenchmark.Timer( | ||||
|         stmt="fn(a, b, scale_a, scale_b, out_dtype)", | ||||
|         globals=globals, | ||||
|         label=label, | ||||
|         sub_label=sub_label, | ||||
|         description=description, | ||||
|     ).blocked_autorange(min_run_time=min_run_time) | ||||
|  | ||||
|  | ||||
| def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, | ||||
|                sub_label: str) -> Iterable[TMeasurement]: | ||||
|     assert dtype == torch.int8 | ||||
|     a, b = make_rand_tensors(torch.int8, m, n, k) | ||||
|     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||
|     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||
|  | ||||
|     timers = [] | ||||
|     # pytorch impl | ||||
|     timers.append( | ||||
|         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), | ||||
|                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, | ||||
|                  torch.bfloat16, label, sub_label, pytorch_i8_impl, | ||||
|                  "pytorch_bf16_bf16_bf16_matmul-no-scales")) | ||||
|  | ||||
|     # cutlass impl | ||||
|     timers.append( | ||||
|         bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), | ||||
|                  torch.bfloat16, label, sub_label, cutlass_impl, | ||||
|                  "cutlass_i8_i8_bf16_scaled_mm")) | ||||
|  | ||||
|     return timers | ||||
|  | ||||
|  | ||||
| def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, | ||||
|               sub_label: str) -> Iterable[TMeasurement]: | ||||
|     assert dtype == torch.float8_e4m3fn | ||||
|     a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) | ||||
|     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||
|     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||
|  | ||||
|     timers = [] | ||||
|  | ||||
|     # pytorch impl: bf16 output, without fp8 fast accum | ||||
|     timers.append( | ||||
|         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, | ||||
|                  pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) | ||||
|  | ||||
|     # pytorch impl: bf16 output, with fp8 fast accum | ||||
|     timers.append( | ||||
|         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, | ||||
|                  pytorch_fp8_impl_fast_accum, | ||||
|                  "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) | ||||
|  | ||||
|     # pytorch impl: fp16 output, without fp8 fast accum | ||||
|     timers.append( | ||||
|         bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, | ||||
|                  pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) | ||||
|  | ||||
|     # pytorch impl: fp16 output, with fp8 fast accum | ||||
|     timers.append( | ||||
|         bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, | ||||
|                  pytorch_fp8_impl_fast_accum, | ||||
|                  "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) | ||||
|  | ||||
|     # cutlass impl: bf16 output | ||||
|     timers.append( | ||||
|         bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), | ||||
|                  torch.bfloat16, label, sub_label, cutlass_impl, | ||||
|                  "cutlass_fp8_fp8_bf16_scaled_mm")) | ||||
|     # cutlass impl: fp16 output | ||||
|     timers.append( | ||||
|         bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), | ||||
|                  torch.float16, label, sub_label, cutlass_impl, | ||||
|                  "cutlass_fp8_fp8_fp16_scaled_mm")) | ||||
|     return timers | ||||
|  | ||||
|  | ||||
| def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, | ||||
|           sub_label: str) -> Iterable[TMeasurement]: | ||||
|     if dtype == torch.int8: | ||||
|         return bench_int8(dtype, m, k, n, label, sub_label) | ||||
|     if dtype == torch.float8_e4m3fn: | ||||
|         return bench_fp8(dtype, m, k, n, label, sub_label) | ||||
|     raise ValueError("unsupported type") | ||||
|  | ||||
|  | ||||
| # runner | ||||
| def print_timers(timers: Iterable[TMeasurement]): | ||||
|     compare = TBenchmark.Compare(timers) | ||||
|     compare.print() | ||||
|  | ||||
|  | ||||
| def run(dtype: torch.dtype, | ||||
|         MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: | ||||
|  | ||||
|     results = [] | ||||
|     for m, k, n in MKNs: | ||||
|         timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", | ||||
|                        f"MKN=({m}x{k}x{n})") | ||||
|         print_timers(timers) | ||||
|         results.extend(timers) | ||||
|  | ||||
|     return results | ||||
|  | ||||
|  | ||||
| # output makers | ||||
| def make_output(data: Iterable[TMeasurement], | ||||
|                 MKNs: Iterable[Tuple[int, int, int]], | ||||
|                 base_description: str, | ||||
|                 timestamp=None): | ||||
|  | ||||
|     print(f"== All Results {base_description} ====") | ||||
|     print_timers(data) | ||||
|  | ||||
|     # pickle all the results | ||||
|     timestamp = int(time.time()) if timestamp is None else timestamp | ||||
|     with open(f"{base_description}-{timestamp}.pkl", "wb") as f: | ||||
|         pkl.dump(data, f) | ||||
|  | ||||
|  | ||||
| # argparse runners | ||||
|  | ||||
|  | ||||
| def run_square_bench(args): | ||||
|     dim_sizes = list( | ||||
|         range(args.dim_start, args.dim_end + 1, args.dim_increment)) | ||||
|     MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) | ||||
|     data = run(args.dtype, MKNs) | ||||
|  | ||||
|     make_output(data, MKNs, f"square_bench-{args.dtype}") | ||||
|  | ||||
|  | ||||
| def run_range_bench(args): | ||||
|     dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) | ||||
|     n = len(dim_sizes) | ||||
|     Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes | ||||
|     Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes | ||||
|     Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes | ||||
|     MKNs = list(zip(Ms, Ks, Ns)) | ||||
|     data = run(args.dtype, MKNs) | ||||
|  | ||||
|     make_output(data, MKNs, f"range_bench-{args.dtype}") | ||||
|  | ||||
|  | ||||
| def run_model_bench(args): | ||||
|  | ||||
|     print("Benchmarking models:") | ||||
|     for i, model in enumerate(args.models): | ||||
|         print(f"[{i}]  {model}") | ||||
|  | ||||
|     def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: | ||||
|         KNs = [] | ||||
|         for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): | ||||
|             KN[tp_split_dim] = KN[tp_split_dim] // tp_size | ||||
|             KNs.append(KN) | ||||
|         return KNs | ||||
|  | ||||
|     model_bench_data = [] | ||||
|     models_tps = list(itertools.product(args.models, args.tp_sizes)) | ||||
|     for model, tp_size in models_tps: | ||||
|         Ms = args.batch_sizes | ||||
|         KNs = model_shapes(model, tp_size) | ||||
|         MKNs = [] | ||||
|         for m in Ms: | ||||
|             for k, n in KNs: | ||||
|                 MKNs.append((m, k, n)) | ||||
|  | ||||
|         data = run(args.dtype, MKNs) | ||||
|         model_bench_data.append(data) | ||||
|  | ||||
|     # Print all results | ||||
|     for data, model_tp in zip(model_bench_data, models_tps): | ||||
|         model, tp_size = model_tp | ||||
|         print(f"== Results {args.dtype} {model}-TP{tp_size} ====") | ||||
|         print_timers(data) | ||||
|  | ||||
|     timestamp = int(time.time()) | ||||
|  | ||||
|     all_data = [] | ||||
|     for d in model_bench_data: | ||||
|         all_data.extend(d) | ||||
|     # pickle all data | ||||
|     with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: | ||||
|         pkl.dump(all_data, f) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     def to_torch_dtype(dt): | ||||
|         if dt == "int8": | ||||
|             return torch.int8 | ||||
|         if dt == "fp8": | ||||
|             return torch.float8_e4m3fn | ||||
|         raise ValueError("unsupported dtype") | ||||
|  | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description=""" | ||||
| Benchmark Cutlass GEMM. | ||||
|  | ||||
|     To run square GEMMs: | ||||
|         python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 | ||||
|      | ||||
|     To run constant N and K and sweep M: | ||||
|         python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 | ||||
|      | ||||
|     To run dimensions from a model: | ||||
|         python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 | ||||
|      | ||||
|     Output: | ||||
|         - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. | ||||
|             """,  # noqa: E501 | ||||
|         formatter_class=argparse.RawTextHelpFormatter) | ||||
|  | ||||
|     parser.add_argument("--dtype", | ||||
|                         type=to_torch_dtype, | ||||
|                         required=True, | ||||
|                         help="Available options are ['int8', 'fp8']") | ||||
|     subparsers = parser.add_subparsers(dest="cmd") | ||||
|  | ||||
|     square_parser = subparsers.add_parser("square_bench") | ||||
|     square_parser.add_argument("--dim-start", type=int, required=True) | ||||
|     square_parser.add_argument("--dim-end", type=int, required=True) | ||||
|     square_parser.add_argument("--dim-increment", type=int, required=True) | ||||
|     square_parser.set_defaults(func=run_square_bench) | ||||
|  | ||||
|     range_parser = subparsers.add_parser("range_bench") | ||||
|     range_parser.add_argument("--dim-start", type=int, required=True) | ||||
|     range_parser.add_argument("--dim-end", type=int, required=True) | ||||
|     range_parser.add_argument("--dim-increment", type=int, required=True) | ||||
|     range_parser.add_argument("--m-constant", type=int, default=None) | ||||
|     range_parser.add_argument("--n-constant", type=int, default=None) | ||||
|     range_parser.add_argument("--k-constant", type=int, default=None) | ||||
|     range_parser.set_defaults(func=run_range_bench) | ||||
|  | ||||
|     model_parser = subparsers.add_parser("model_bench") | ||||
|     model_parser.add_argument("--models", | ||||
|                               nargs="+", | ||||
|                               type=str, | ||||
|                               default=DEFAULT_MODELS, | ||||
|                               choices=WEIGHT_SHAPES.keys()) | ||||
|     model_parser.add_argument("--tp-sizes", | ||||
|                               nargs="+", | ||||
|                               type=int, | ||||
|                               default=DEFAULT_TP_SIZES) | ||||
|     model_parser.add_argument("--batch-sizes", | ||||
|                               nargs="+", | ||||
|                               type=int, | ||||
|                               default=DEFAULT_BATCH_SIZES) | ||||
|     model_parser.set_defaults(func=run_model_bench) | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|     args.func(args) | ||||
							
								
								
									
										37
									
								
								benchmarks/cutlass_benchmarks/weight_shapes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								benchmarks/cutlass_benchmarks/weight_shapes.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,37 @@ | ||||
| # Weight Shapes are in the format | ||||
| # ([K, N], TP_SPLIT_DIM) | ||||
| # Example: | ||||
| #  A shape of ([14336, 4096], 0) indicates the following GEMM shape, | ||||
| #   - TP1 : K = 14336, N = 4096 | ||||
| #   - TP2 : K = 7168, N = 4096 | ||||
| #  A shape of ([4096, 6144], 1) indicates the following GEMM shape, | ||||
| #   - TP1 : K = 4096, N = 6144 | ||||
| #   - TP4 : K = 4096, N = 1536 | ||||
|  | ||||
| # TP1 shapes | ||||
| WEIGHT_SHAPES = { | ||||
|     "mistralai/Mistral-7B-v0.1": [ | ||||
|         ([4096, 6144], 1), | ||||
|         ([4096, 4096], 0), | ||||
|         ([4096, 28672], 1), | ||||
|         ([14336, 4096], 0), | ||||
|     ], | ||||
|     "meta-llama/Llama-2-7b-hf": [ | ||||
|         ([4096, 12288], 1), | ||||
|         ([4096, 4096], 0), | ||||
|         ([4096, 22016], 1), | ||||
|         ([11008, 4096], 0), | ||||
|     ], | ||||
|     "meta-llama/Llama-2-13b-hf": [ | ||||
|         ([5120, 15360], 1), | ||||
|         ([5120, 5120], 0), | ||||
|         ([5120, 27648], 1), | ||||
|         ([13824, 5120], 0), | ||||
|     ], | ||||
|     "meta-llama/Llama-2-70b-hf": [ | ||||
|         ([8192, 10240], 1), | ||||
|         ([8192, 8192], 0), | ||||
|         ([8192, 57344], 1), | ||||
|         ([28672, 8192], 0), | ||||
|     ], | ||||
| } | ||||
							
								
								
									
										233
									
								
								benchmarks/kernels/benchmark_marlin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								benchmarks/kernels/benchmark_marlin.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,233 @@ | ||||
| import argparse | ||||
|  | ||||
| import torch | ||||
| import torch.utils.benchmark as benchmark | ||||
| from benchmark_shapes import WEIGHT_SHAPES | ||||
|  | ||||
| from vllm import _custom_ops as ops | ||||
| from vllm.model_executor.layers.quantization.gptq_marlin import ( | ||||
|     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, | ||||
|     GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) | ||||
| from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( | ||||
|     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, | ||||
|     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) | ||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils import ( | ||||
|     MarlinWorkspace, marlin_24_quantize, marlin_quantize) | ||||
| from vllm.model_executor.layers.quantization.utils.quant_utils import ( | ||||
|     gptq_pack, quantize_weights, sort_weights) | ||||
|  | ||||
| DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] | ||||
| DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] | ||||
|  | ||||
| ACT_ORDER_OPTS = [False, True] | ||||
| K_FULL_OPTS = [False, True] | ||||
|  | ||||
|  | ||||
| def bench_run(results, model, act_order, is_k_full, num_bits, group_size, | ||||
|               size_m, size_k, size_n): | ||||
|     label = "Quant Matmul" | ||||
|  | ||||
|     sub_label = ("{}, act={} k_full={}, b={}, g={}, " | ||||
|                  "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, | ||||
|                                          group_size, size_m, size_k, size_n)) | ||||
|  | ||||
|     print(f"Testing: {sub_label}") | ||||
|  | ||||
|     a = torch.randn(size_m, size_k).to(torch.half).cuda() | ||||
|     b = torch.rand(size_k, size_n).to(torch.half).cuda() | ||||
|  | ||||
|     a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) | ||||
|  | ||||
|     # Marlin quant | ||||
|     ( | ||||
|         marlin_w_ref, | ||||
|         marlin_q_w, | ||||
|         marlin_s, | ||||
|         marlin_g_idx, | ||||
|         marlin_sort_indices, | ||||
|         marlin_rand_perm, | ||||
|     ) = marlin_quantize(b, num_bits, group_size, act_order) | ||||
|  | ||||
|     # Marlin_24 quant | ||||
|     (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, | ||||
|      marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) | ||||
|  | ||||
|     # GPTQ quant | ||||
|     (w_ref, q_w, s, g_idx, | ||||
|      rand_perm) = quantize_weights(b, num_bits, group_size, act_order) | ||||
|     q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) | ||||
|  | ||||
|     # For act_order, sort the "weights" and "g_idx" | ||||
|     # so that group ids are increasing | ||||
|     repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) | ||||
|     if act_order: | ||||
|         (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) | ||||
|  | ||||
|     # Prepare | ||||
|     marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, | ||||
|                                        GPTQ_MARLIN_MAX_PARALLEL) | ||||
|  | ||||
|     marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, | ||||
|                                           GPTQ_MARLIN_24_MAX_PARALLEL) | ||||
|  | ||||
|     globals = { | ||||
|         # Gen params | ||||
|         "num_bits": num_bits, | ||||
|         "group_size": group_size, | ||||
|         "size_m": size_m, | ||||
|         "size_n": size_n, | ||||
|         "size_k": size_k, | ||||
|         "a": a, | ||||
|         "a_tmp": a_tmp, | ||||
|         # Marlin params | ||||
|         "marlin_w_ref": marlin_w_ref, | ||||
|         "marlin_q_w": marlin_q_w, | ||||
|         "marlin_s": marlin_s, | ||||
|         "marlin_g_idx": marlin_g_idx, | ||||
|         "marlin_sort_indices": marlin_sort_indices, | ||||
|         "marlin_rand_perm": marlin_rand_perm, | ||||
|         "marlin_workspace": marlin_workspace, | ||||
|         "is_k_full": is_k_full, | ||||
|         # Marlin_24 params | ||||
|         "marlin_24_w_ref": marlin_24_w_ref, | ||||
|         "marlin_24_q_w_comp": marlin_24_q_w_comp, | ||||
|         "marlin_24_meta": marlin_24_meta, | ||||
|         "marlin_24_s": marlin_24_s, | ||||
|         "marlin_24_workspace": marlin_24_workspace, | ||||
|         # GPTQ params | ||||
|         "q_w_gptq": q_w_gptq, | ||||
|         "repack_sort_indices": repack_sort_indices, | ||||
|         # Kernels | ||||
|         "gptq_marlin_gemm": ops.gptq_marlin_gemm, | ||||
|         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, | ||||
|         "gptq_marlin_repack": ops.gptq_marlin_repack, | ||||
|     } | ||||
|  | ||||
|     min_run_time = 1 | ||||
|  | ||||
|     # Warmup pytorch | ||||
|     for i in range(5): | ||||
|         torch.matmul(a, marlin_w_ref) | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt="torch.matmul(a, marlin_w_ref)", | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
|             description="pytorch_gemm", | ||||
|         ).blocked_autorange(min_run_time=min_run_time)) | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt= | ||||
|             "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501 | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
|             description="gptq_marlin_gemm", | ||||
|         ).blocked_autorange(min_run_time=min_run_time)) | ||||
|  | ||||
|     if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS | ||||
|             and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): | ||||
|         results.append( | ||||
|             benchmark.Timer( | ||||
|                 stmt= | ||||
|                 "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501 | ||||
|                 globals=globals, | ||||
|                 label=label, | ||||
|                 sub_label=sub_label, | ||||
|                 description="gptq_marlin_24_gemm", | ||||
|             ).blocked_autorange(min_run_time=min_run_time)) | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt= | ||||
|             "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501 | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
|             description="gptq_marlin_repack", | ||||
|         ).blocked_autorange(min_run_time=min_run_time)) | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     print("Benchmarking models:") | ||||
|     for i, model in enumerate(args.models): | ||||
|         print(f"[{i}]  {model}") | ||||
|  | ||||
|     results = [] | ||||
|  | ||||
|     for model in args.models: | ||||
|         for layer in WEIGHT_SHAPES[model]: | ||||
|             size_k = layer[0] | ||||
|             size_n = layer[1] | ||||
|  | ||||
|             if len(args.limit_k) > 0 and size_k not in args.limit_k: | ||||
|                 continue | ||||
|  | ||||
|             if len(args.limit_n) > 0 and size_n not in args.limit_n: | ||||
|                 continue | ||||
|  | ||||
|             for act_order in ACT_ORDER_OPTS: | ||||
|                 if len(args.limit_act_order | ||||
|                        ) > 0 and act_order not in args.limit_act_order: | ||||
|                     continue | ||||
|  | ||||
|                 for is_k_full in K_FULL_OPTS: | ||||
|                     if len(args.limit_k_full | ||||
|                            ) > 0 and is_k_full not in args.limit_k_full: | ||||
|                         continue | ||||
|  | ||||
|                     for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: | ||||
|                         if len(args.limit_num_bits | ||||
|                                ) > 0 and num_bits not in args.limit_num_bits: | ||||
|                             continue | ||||
|  | ||||
|                         for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: | ||||
|                             if len( | ||||
|                                     args.limit_group_size | ||||
|                             ) > 0 and group_size not in args.limit_group_size: | ||||
|                                 continue | ||||
|  | ||||
|                             # For act_order, the group_size must be less than | ||||
|                             # size_k | ||||
|                             if act_order and (group_size == size_k | ||||
|                                               or group_size == -1): | ||||
|                                 continue | ||||
|  | ||||
|                             for size_m in args.batch_sizes: | ||||
|                                 bench_run(results, model, act_order, is_k_full, | ||||
|                                           num_bits, group_size, size_m, size_k, | ||||
|                                           size_n) | ||||
|  | ||||
|     compare = benchmark.Compare(results) | ||||
|     compare.print() | ||||
|  | ||||
|  | ||||
| # For quick benchmarking use: | ||||
| #   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 | ||||
| # | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Benchmark Marlin across specified models/shapes/batches") | ||||
|     parser.add_argument( | ||||
|         "--models", | ||||
|         nargs="+", | ||||
|         type=str, | ||||
|         default=DEFAULT_MODELS, | ||||
|         choices=WEIGHT_SHAPES.keys(), | ||||
|     ) | ||||
|     parser.add_argument("--batch-sizes", | ||||
|                         nargs="+", | ||||
|                         type=int, | ||||
|                         default=DEFAULT_BATCH_SIZES) | ||||
|     parser.add_argument("--limit-k", nargs="+", type=int, default=[]) | ||||
|     parser.add_argument("--limit-n", nargs="+", type=int, default=[]) | ||||
|     parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) | ||||
|     parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) | ||||
|     parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) | ||||
|     parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
| @ -11,25 +11,36 @@ from tqdm import tqdm | ||||
| from vllm.model_executor.layers.fused_moe import (fused_moe, | ||||
|                                                   get_config_file_name) | ||||
|  | ||||
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | ||||
|  | ||||
|  | ||||
| def main(dtype: str): | ||||
| def main(model, tp_size, gpu, dtype: str): | ||||
|     os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) | ||||
|     method = fused_moe | ||||
|     for bs in [ | ||||
|             1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, | ||||
|             2048, 3072, 4096 | ||||
|     ]: | ||||
|         run_grid(bs, method=method, dtype=dtype) | ||||
|         run_grid(bs, | ||||
|                  model=model, | ||||
|                  method=method, | ||||
|                  gpu=gpu, | ||||
|                  tp_size=tp_size, | ||||
|                  dtype=dtype) | ||||
|  | ||||
|  | ||||
| def run_grid(bs, method, dtype: str): | ||||
|     d_model = 4096 | ||||
| def run_grid(bs, model, method, gpu, tp_size, dtype: str): | ||||
|     if model == '8x7B': | ||||
|         d_model = 4096 | ||||
|         model_intermediate_size = 14336 | ||||
|         num_layers = 32 | ||||
|     elif model == '8x22B': | ||||
|         d_model = 6144 | ||||
|         model_intermediate_size = 16384 | ||||
|         num_layers = 56 | ||||
|     else: | ||||
|         raise ValueError(f'Unsupported Mixtral model {model}') | ||||
|     num_total_experts = 8 | ||||
|     top_k = 2 | ||||
|     tp_size = 2 | ||||
|     model_intermediate_size = 14336 | ||||
|     num_layers = 32 | ||||
|     # tp_size = 2 | ||||
|     num_calls = 100 | ||||
|  | ||||
|     num_warmup_trials = 1 | ||||
| @ -211,5 +222,18 @@ if __name__ == "__main__": | ||||
|         choices=['float8', 'float16'], | ||||
|         help='Data type used for fused_moe kernel computations', | ||||
|     ) | ||||
|     parser.add_argument('--model', | ||||
|                         type=str, | ||||
|                         default='8x7B', | ||||
|                         choices=['8x7B', '8x22B'], | ||||
|                         help='The Mixtral model to benchmark') | ||||
|     parser.add_argument('--tp-size', | ||||
|                         type=int, | ||||
|                         default=2, | ||||
|                         help='Tensor paralleli size') | ||||
|     parser.add_argument('--gpu', | ||||
|                         type=int, | ||||
|                         default=0, | ||||
|                         help="GPU ID for benchmarking") | ||||
|     args = parser.parse_args() | ||||
|     sys.exit(main(args.dtype)) | ||||
|     sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype)) | ||||
|  | ||||
| @ -170,7 +170,7 @@ if __name__ == '__main__': | ||||
|     parser.add_argument("--num-kv-heads", type=int, default=8) | ||||
|     parser.add_argument("--head-size", | ||||
|                         type=int, | ||||
|                         choices=[64, 80, 96, 112, 128, 256], | ||||
|                         choices=[64, 80, 96, 112, 128, 192, 256], | ||||
|                         default=128) | ||||
|     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) | ||||
|     parser.add_argument("--use-alibi", action="store_true") | ||||
| @ -183,13 +183,11 @@ if __name__ == '__main__': | ||||
|     parser.add_argument( | ||||
|         "--kv-cache-dtype", | ||||
|         type=str, | ||||
|         choices=["auto", "fp8"], | ||||
|         choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], | ||||
|         default="auto", | ||||
|         help= | ||||
|         'Data type for kv cache storage. If "auto", will use model data type. ' | ||||
|         'FP8_E5M2 (without scaling) is only supported on cuda version greater ' | ||||
|         'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' | ||||
|         'common inference criteria.') | ||||
|         help="Data type for kv cache storage. If 'auto', will use model " | ||||
|         "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " | ||||
|         "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") | ||||
|     args = parser.parse_args() | ||||
|     print(args) | ||||
|  | ||||
|  | ||||
| @ -93,7 +93,7 @@ if __name__ == '__main__': | ||||
|     parser.add_argument("--num-heads", type=int, default=8) | ||||
|     parser.add_argument("--head-size", | ||||
|                         type=int, | ||||
|                         choices=[64, 80, 96, 112, 128, 256], | ||||
|                         choices=[64, 80, 96, 112, 128, 192, 256], | ||||
|                         default=128) | ||||
|     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) | ||||
|     parser.add_argument("--dtype", | ||||
|  | ||||
							
								
								
									
										75
									
								
								benchmarks/kernels/benchmark_shapes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								benchmarks/kernels/benchmark_shapes.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,75 @@ | ||||
| WEIGHT_SHAPES = { | ||||
|     "ideal": [[4 * 256 * 32, 256 * 32]], | ||||
|     "mistralai/Mistral-7B-v0.1/TP1": [ | ||||
|         [4096, 6144], | ||||
|         [4096, 4096], | ||||
|         [4096, 28672], | ||||
|         [14336, 4096], | ||||
|     ], | ||||
|     "mistralai/Mistral-7B-v0.1/TP2": [ | ||||
|         [4096, 3072], | ||||
|         [2048, 4096], | ||||
|         [4096, 14336], | ||||
|         [7168, 4096], | ||||
|     ], | ||||
|     "mistralai/Mistral-7B-v0.1/TP4": [ | ||||
|         [4096, 1536], | ||||
|         [1024, 4096], | ||||
|         [4096, 7168], | ||||
|         [3584, 4096], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-7b-hf/TP1": [ | ||||
|         [4096, 12288], | ||||
|         [4096, 4096], | ||||
|         [4096, 22016], | ||||
|         [11008, 4096], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-7b-hf/TP2": [ | ||||
|         [4096, 6144], | ||||
|         [2048, 4096], | ||||
|         [4096, 11008], | ||||
|         [5504, 4096], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-7b-hf/TP4": [ | ||||
|         [4096, 3072], | ||||
|         [1024, 4096], | ||||
|         [4096, 5504], | ||||
|         [2752, 4096], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-13b-hf/TP1": [ | ||||
|         [5120, 15360], | ||||
|         [5120, 5120], | ||||
|         [5120, 27648], | ||||
|         [13824, 5120], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-13b-hf/TP2": [ | ||||
|         [5120, 7680], | ||||
|         [2560, 5120], | ||||
|         [5120, 13824], | ||||
|         [6912, 5120], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-13b-hf/TP4": [ | ||||
|         [5120, 3840], | ||||
|         [1280, 5120], | ||||
|         [5120, 6912], | ||||
|         [3456, 5120], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-70b-hf/TP1": [ | ||||
|         [8192, 10240], | ||||
|         [8192, 8192], | ||||
|         [8192, 57344], | ||||
|         [28672, 8192], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-70b-hf/TP2": [ | ||||
|         [8192, 5120], | ||||
|         [4096, 8192], | ||||
|         [8192, 28672], | ||||
|         [14336, 8192], | ||||
|     ], | ||||
|     "meta-llama/Llama-2-70b-hf/TP4": [ | ||||
|         [8192, 2560], | ||||
|         [2048, 8192], | ||||
|         [8192, 14336], | ||||
|         [7168, 8192], | ||||
|     ], | ||||
| } | ||||
| @ -4,7 +4,7 @@ PORT=8000 | ||||
| MODEL=$1 | ||||
| TOKENS=$2 | ||||
|  | ||||
| docker run --gpus all --shm-size 1g -p $PORT:80 \ | ||||
| docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ | ||||
|            -v $PWD/data:/data \ | ||||
|            ghcr.io/huggingface/text-generation-inference:1.4.0 \ | ||||
|            --model-id $MODEL \ | ||||
|  | ||||
							
								
								
									
										63
									
								
								benchmarks/overheads/benchmark_hashing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								benchmarks/overheads/benchmark_hashing.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,63 @@ | ||||
| import argparse | ||||
| import cProfile | ||||
| import pstats | ||||
|  | ||||
| from vllm import LLM, SamplingParams | ||||
|  | ||||
| # A very long prompt, total number of tokens is about 15k. | ||||
| LONG_PROMPT = ["You are an expert in large language models, aren't you?" | ||||
|                ] * 1000 | ||||
| LONG_PROMPT = ' '.join(LONG_PROMPT) | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     llm = LLM( | ||||
|         model=args.model, | ||||
|         enforce_eager=True, | ||||
|         enable_prefix_caching=True, | ||||
|         tensor_parallel_size=args.tensor_parallel_size, | ||||
|         use_v2_block_manager=args.use_v2_block_manager, | ||||
|     ) | ||||
|  | ||||
|     sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) | ||||
|     profiler = cProfile.Profile() | ||||
|  | ||||
|     print("------warm up------") | ||||
|     for i in range(3): | ||||
|         output = llm.generate(LONG_PROMPT, sampling_params) | ||||
|         print(output[0].outputs[0].text) | ||||
|  | ||||
|     print("------start generating------") | ||||
|     for i in range(3): | ||||
|         profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', | ||||
|                         globals(), locals()) | ||||
|  | ||||
|     # analyze the runtime of hashing function | ||||
|     stats = pstats.Stats(profiler) | ||||
|     stats.sort_stats('cumulative') | ||||
|     total_time = 0 | ||||
|     total_calls = 0 | ||||
|     for func in stats.stats: | ||||
|         if 'hash_of_block' in func[2]: | ||||
|             total_time = stats.stats[func][3] | ||||
|             total_calls = stats.stats[func][0] | ||||
|     percentage = (total_time / stats.total_tt) * 100 | ||||
|     print(f"Hashing took {total_time:.2f} seconds," | ||||
|           f"{percentage:.2f}% of the total runtime.") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description='Benchmark the performance of hashing function in' | ||||
|         'automatic prefix caching.') | ||||
|     parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') | ||||
|     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) | ||||
|     parser.add_argument('--output-len', type=int, default=10) | ||||
|     parser.add_argument('--enable-prefix-caching', | ||||
|                         action='store_true', | ||||
|                         help='enable prefix caching') | ||||
|     parser.add_argument('--use-v2-block-manager', | ||||
|                         action='store_true', | ||||
|                         help='Use BlockSpaceMangerV2') | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
| @ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) | ||||
|       "Failed to determine torch nvcc compiler flags") | ||||
|  | ||||
|     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) | ||||
|       list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") | ||||
|       list(APPEND GPU_FLAGS "-DENABLE_FP8") | ||||
|     endif() | ||||
|     if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) | ||||
|       list(REMOVE_ITEM GPU_FLAGS | ||||
| @ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) | ||||
|  | ||||
|     list(APPEND GPU_FLAGS | ||||
|       "-DUSE_ROCM" | ||||
|       "-DENABLE_FP8_E4M3" | ||||
|       "-DENABLE_FP8" | ||||
|       "-U__HIP_NO_HALF_CONVERSIONS__" | ||||
|       "-U__HIP_NO_HALF_OPERATORS__" | ||||
|       "-fno-gpu-rdc") | ||||
|  | ||||
| @ -10,11 +10,11 @@ | ||||
| namespace vllm { | ||||
|  | ||||
| // Activation and gating kernel template. | ||||
| template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> | ||||
| template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> | ||||
| __global__ void act_and_mul_kernel( | ||||
|   scalar_t* __restrict__ out,               // [..., d] | ||||
|   const scalar_t* __restrict__ input,       // [..., 2, d] | ||||
|   const int d) { | ||||
|     scalar_t* __restrict__ out,          // [..., d] | ||||
|     const scalar_t* __restrict__ input,  // [..., 2, d] | ||||
|     const int d) { | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { | ||||
|     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); | ||||
| @ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( | ||||
|   } | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| __device__ __forceinline__ T silu_kernel(const T& x) { | ||||
|   // x * sigmoid(x) | ||||
|   return (T) (((float) x) / (1.0f + expf((float) -x))); | ||||
|   return (T)(((float)x) / (1.0f + expf((float)-x))); | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| __device__ __forceinline__ T gelu_kernel(const T& x) { | ||||
|   // Equivalent to PyTorch GELU with 'none' approximation. | ||||
|   // Refer to: | ||||
|   // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 | ||||
|   const float f = (float) x; | ||||
|   const float f = (float)x; | ||||
|   constexpr float ALPHA = M_SQRT1_2; | ||||
|   return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); | ||||
|   return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { | ||||
|   // Equivalent to PyTorch GELU with 'tanh' approximation. | ||||
|   // Refer to: | ||||
|   // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 | ||||
|   const float f = (float) x; | ||||
|   const float f = (float)x; | ||||
|   constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; | ||||
|   constexpr float KAPPA = 0.044715; | ||||
|   float x_cube = f * f * f; | ||||
|   float inner = BETA * (f + KAPPA * x_cube); | ||||
|   return (T) (0.5f * f * (1.0f + ::tanhf(inner))); | ||||
|   return (T)(0.5f * f * (1.0f + ::tanhf(inner))); | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| // Launch activation and gating kernel. | ||||
| #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                             \ | ||||
|   int d = input.size(-1) / 2;                                                             \ | ||||
|   int64_t num_tokens = input.numel() / input.size(-1);                                    \ | ||||
|   dim3 grid(num_tokens);                                                                  \ | ||||
|   dim3 block(std::min(d, 1024));                                                          \ | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \ | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \ | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(                                                           \ | ||||
|     input.scalar_type(),                                                                  \ | ||||
|     "act_and_mul_kernel",                                                                 \ | ||||
|     [&] {                                                                                 \ | ||||
|       vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(   \ | ||||
|         out.data_ptr<scalar_t>(),                                                         \ | ||||
|         input.data_ptr<scalar_t>(),                                                       \ | ||||
|         d);                                                                               \ | ||||
|     }); | ||||
| #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \ | ||||
|   int d = input.size(-1) / 2;                                            \ | ||||
|   int64_t num_tokens = input.numel() / input.size(-1);                   \ | ||||
|   dim3 grid(num_tokens);                                                 \ | ||||
|   dim3 block(std::min(d, 1024));                                         \ | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \ | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \ | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(                                          \ | ||||
|       input.scalar_type(), "act_and_mul_kernel", [&] {                   \ | ||||
|         vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \ | ||||
|             <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \ | ||||
|                                          input.data_ptr<scalar_t>(), d); \ | ||||
|       }); | ||||
|  | ||||
| void silu_and_mul( | ||||
|   torch::Tensor& out,      // [..., d] | ||||
|   torch::Tensor& input)    // [..., 2 * d] | ||||
| void silu_and_mul(torch::Tensor& out,    // [..., d] | ||||
|                   torch::Tensor& input)  // [..., 2 * d] | ||||
| { | ||||
|   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); | ||||
| } | ||||
|  | ||||
| void gelu_and_mul( | ||||
|   torch::Tensor& out,      // [..., d] | ||||
|   torch::Tensor& input)    // [..., 2 * d] | ||||
| void gelu_and_mul(torch::Tensor& out,    // [..., d] | ||||
|                   torch::Tensor& input)  // [..., 2 * d] | ||||
| { | ||||
|   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); | ||||
| } | ||||
|  | ||||
| void gelu_tanh_and_mul( | ||||
|   torch::Tensor& out,      // [..., d] | ||||
|   torch::Tensor& input)    // [..., 2 * d] | ||||
| void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d] | ||||
|                        torch::Tensor& input)  // [..., 2 * d] | ||||
| { | ||||
|   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); | ||||
| } | ||||
| @ -96,11 +90,11 @@ void gelu_tanh_and_mul( | ||||
| namespace vllm { | ||||
|  | ||||
| // Element-wise activation kernel template. | ||||
| template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> | ||||
| template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> | ||||
| __global__ void activation_kernel( | ||||
|   scalar_t* __restrict__ out,               // [..., d] | ||||
|   const scalar_t* __restrict__ input,       // [..., d] | ||||
|   const int d) { | ||||
|     scalar_t* __restrict__ out,          // [..., d] | ||||
|     const scalar_t* __restrict__ input,  // [..., d] | ||||
|     const int d) { | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { | ||||
|     const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); | ||||
| @ -108,54 +102,49 @@ __global__ void activation_kernel( | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| // Launch element-wise activation kernel. | ||||
| #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \ | ||||
|   int d = input.size(-1);                                                                 \ | ||||
|   int64_t num_tokens = input.numel() / d;                                                 \ | ||||
|   dim3 grid(num_tokens);                                                                  \ | ||||
|   dim3 block(std::min(d, 1024));                                                          \ | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \ | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \ | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(                                                           \ | ||||
|     input.scalar_type(),                                                                  \ | ||||
|     "activation_kernel",                                                                  \ | ||||
|     [&] {                                                                                 \ | ||||
|       vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \ | ||||
|         out.data_ptr<scalar_t>(),                                                         \ | ||||
|         input.data_ptr<scalar_t>(),                                                       \ | ||||
|         d);                                                                               \ | ||||
|     }); | ||||
| #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \ | ||||
|   int d = input.size(-1);                                                      \ | ||||
|   int64_t num_tokens = input.numel() / d;                                      \ | ||||
|   dim3 grid(num_tokens);                                                       \ | ||||
|   dim3 block(std::min(d, 1024));                                               \ | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \ | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \ | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ | ||||
|     vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                        \ | ||||
|         <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \ | ||||
|                                      input.data_ptr<scalar_t>(), d);           \ | ||||
|   }); | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| __device__ __forceinline__ T gelu_new_kernel(const T& x) { | ||||
|   const float x3 = (float) (x * x * x); | ||||
|   const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); | ||||
|   return ((T) 0.5) * x * (((T) 1.0) + t); | ||||
|   const float x3 = (float)(x * x * x); | ||||
|   const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); | ||||
|   return ((T)0.5) * x * (((T)1.0) + t); | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| __device__ __forceinline__ T gelu_fast_kernel(const T& x) { | ||||
|   const float f = (float) x; | ||||
|   const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); | ||||
|   return ((T) 0.5) * x * (((T) 1.0) + t); | ||||
|   const float f = (float)x; | ||||
|   const T t = | ||||
|       (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); | ||||
|   return ((T)0.5) * x * (((T)1.0) + t); | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| void gelu_new( | ||||
|   torch::Tensor& out,     // [..., d] | ||||
|   torch::Tensor& input)   // [..., d] | ||||
| void gelu_new(torch::Tensor& out,    // [..., d] | ||||
|               torch::Tensor& input)  // [..., d] | ||||
| { | ||||
|   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); | ||||
| } | ||||
|  | ||||
| void gelu_fast( | ||||
|   torch::Tensor& out,     // [..., d] | ||||
|   torch::Tensor& input)   // [..., d] | ||||
| void gelu_fast(torch::Tensor& out,    // [..., d] | ||||
|                torch::Tensor& input)  // [..., d] | ||||
| { | ||||
|   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); | ||||
| } | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| /* | ||||
|  * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Adapted from | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Copyright (c) 2023, The vLLM team. | ||||
|  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved. | ||||
|  * | ||||
| @ -22,31 +23,31 @@ | ||||
| namespace vllm { | ||||
|  | ||||
| // A vector type to store Q, K, V elements. | ||||
| template<typename T, int VEC_SIZE> | ||||
| template <typename T, int VEC_SIZE> | ||||
| struct Vec {}; | ||||
|  | ||||
| // A vector type to store FP32 accumulators. | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| struct FloatVec {}; | ||||
|  | ||||
| // Template vector operations. | ||||
| template<typename Acc, typename A, typename B> | ||||
| template <typename Acc, typename A, typename B> | ||||
| inline __device__ Acc mul(A a, B b); | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| inline __device__ float sum(T v); | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| inline __device__ float dot(T a, T b) { | ||||
|   return sum(mul<T, T, T>(a, b)); | ||||
| } | ||||
|  | ||||
| template<typename A, typename T> | ||||
| template <typename A, typename T> | ||||
| inline __device__ float dot(T a, T b) { | ||||
|   return sum(mul<A, T, T>(a, b)); | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| template <typename T> | ||||
| inline __device__ void zero(T& dst) { | ||||
|   constexpr int WORDS = sizeof(T) / 4; | ||||
|   union { | ||||
| @ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { | ||||
|   dst = tmp.raw; | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,5 +1,6 @@ | ||||
| /* | ||||
|  * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * Adapted from | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * Copyright (c) 2023, The vLLM team. | ||||
|  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved. | ||||
|  * | ||||
| @ -26,7 +27,7 @@ | ||||
| namespace vllm { | ||||
|  | ||||
| // Q*K^T operation. | ||||
| template<int THREAD_GROUP_SIZE, typename Vec, int N> | ||||
| template <int THREAD_GROUP_SIZE, typename Vec, int N> | ||||
| inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { | ||||
|   using A_vec = typename FloatVec<Vec>::Type; | ||||
|   // Compute the parallel products for Q*K^T (treat vector lanes separately). | ||||
| @ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { | ||||
|   return qk; | ||||
| } | ||||
|  | ||||
| template<typename T, int THREAD_GROUP_SIZE> | ||||
| template <typename T, int THREAD_GROUP_SIZE> | ||||
| struct Qk_dot { | ||||
|   template<typename Vec, int N> | ||||
|   template <typename Vec, int N> | ||||
|   static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { | ||||
|     return qk_dot_<THREAD_GROUP_SIZE>(q, k); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| /* | ||||
|  * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Adapted from | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * and | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Copyright (c) 2023, The vLLM team. | ||||
|  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved. | ||||
|  * | ||||
| @ -28,8 +30,8 @@ | ||||
|   #include <hip/hip_bf16.h> | ||||
|   #include <hip/hip_fp16.h> | ||||
|  | ||||
|   typedef __hip_bfloat162 __nv_bfloat162; | ||||
|   typedef __hip_bfloat16 __nv_bfloat16; | ||||
| typedef __hip_bfloat162 __nv_bfloat162; | ||||
| typedef __hip_bfloat16 __nv_bfloat16; | ||||
| #endif | ||||
|  | ||||
| #include <stdint.h> | ||||
| @ -50,37 +52,37 @@ struct bf16_8_t { | ||||
| }; | ||||
|  | ||||
| // BF16 vector types for Q, K, V. | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<__nv_bfloat16, 1> { | ||||
|   using Type = __nv_bfloat16; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<__nv_bfloat16, 2> { | ||||
|   using Type = __nv_bfloat162; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<__nv_bfloat16, 4> { | ||||
|   using Type = bf16_4_t; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<__nv_bfloat16, 8> { | ||||
|   using Type = bf16_8_t; | ||||
| }; | ||||
|  | ||||
| // FP32 accumulator vector types corresponding to Vec. | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<__nv_bfloat16> { | ||||
|   using Type = float; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<__nv_bfloat162> { | ||||
|   using Type = float2; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<bf16_4_t> { | ||||
|   using Type = Float4_; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<bf16_8_t> { | ||||
|   using Type = Float8_; | ||||
| }; | ||||
| @ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { | ||||
|   assert(false); | ||||
| #else | ||||
|   #ifndef USE_ROCM | ||||
|     return a + b; | ||||
|   return a + b; | ||||
|   #else | ||||
|     return __hadd(a, b); | ||||
|   return __hadd(a, b); | ||||
|   #endif | ||||
| #endif | ||||
| } | ||||
| @ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { | ||||
| } | ||||
|  | ||||
| // Vector multiplication. | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|   assert(false); | ||||
| @ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|   assert(false); | ||||
| @ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { | ||||
|   return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { | ||||
|   bf16_4_t c; | ||||
|   c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); | ||||
| @ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { | ||||
|   __nv_bfloat162 s = bf162bf162(a); | ||||
|   bf16_4_t c; | ||||
| @ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { | ||||
|   bf16_8_t c; | ||||
|   c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); | ||||
| @ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { | ||||
|   __nv_bfloat162 s = bf162bf162(a); | ||||
|   bf16_8_t c; | ||||
| @ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { | ||||
|   float fa = __bfloat162float(a); | ||||
|   float fb = __bfloat162float(b); | ||||
|   return fa * fb; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { | ||||
|   float2 fa = bf1622float2(a); | ||||
|   float2 fb = bf1622float2(b); | ||||
|   return mul<float2, float2, float2>(fa, fb); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { | ||||
|   return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { | ||||
|   Float4_ fc; | ||||
|   fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); | ||||
| @ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { | ||||
|   return fc; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { | ||||
|   __nv_bfloat162 s = bf162bf162(a); | ||||
|   Float4_ fc; | ||||
| @ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { | ||||
|   return fc; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { | ||||
|   Float8_ fc; | ||||
|   fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); | ||||
| @ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { | ||||
|   return fc; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { | ||||
|   __nv_bfloat162 s = bf162bf162(a); | ||||
|   Float8_ fc; | ||||
| @ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { | ||||
| } | ||||
|  | ||||
| // Vector fused multiply-add. | ||||
| inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { | ||||
| inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, | ||||
|                                      __nv_bfloat162 c) { | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|   assert(false); | ||||
| #else | ||||
| @ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf | ||||
| #endif | ||||
| } | ||||
|  | ||||
| inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { | ||||
| inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, | ||||
|                                      __nv_bfloat162 c) { | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|   assert(false); | ||||
| #else | ||||
| @ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { | ||||
| } | ||||
|  | ||||
| // Vector sum. | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(__nv_bfloat16 v) { | ||||
|   return __bfloat162float(v); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(__nv_bfloat162 v) { | ||||
|   float2 vf = bf1622float2(v); | ||||
|   return vf.x + vf.y; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(bf16_4_t v) { | ||||
|   return sum(v.x) + sum(v.y); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(bf16_8_t v) { | ||||
|   return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); | ||||
| } | ||||
| @ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { | ||||
| #endif | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| /* | ||||
|  * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Adapted from | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * and | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Copyright (c) 2023, The vLLM team. | ||||
|  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved. | ||||
|  * | ||||
| @ -30,37 +32,37 @@ | ||||
| namespace vllm { | ||||
|  | ||||
| // FP16 vector types for Q, K, V. | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint16_t, 1> { | ||||
|   using Type = uint16_t; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint16_t, 2> { | ||||
|   using Type = uint32_t; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint16_t, 4> { | ||||
|   using Type = uint2; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint16_t, 8> { | ||||
|   using Type = uint4; | ||||
| }; | ||||
|  | ||||
| // FP32 accumulator vector types corresponding to Vec. | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<uint16_t> { | ||||
|   using Type = float; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<uint32_t> { | ||||
|   using Type = float2; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<uint2> { | ||||
|   using Type = Float4_; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<uint4> { | ||||
|   using Type = Float8_; | ||||
| }; | ||||
| @ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { | ||||
|   return b; | ||||
| #else | ||||
|   union { | ||||
|    uint32_t u32; | ||||
|    uint16_t u16[2]; | ||||
|     uint32_t u32; | ||||
|     uint16_t u16[2]; | ||||
|   } tmp; | ||||
|   tmp.u16[0] = a; | ||||
|   tmp.u16[1] = a; | ||||
| @ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { | ||||
|   } tmp; | ||||
| #ifndef USE_ROCM | ||||
|   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 | ||||
|     asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); | ||||
|   asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" | ||||
|                : "=r"(tmp.u32) | ||||
|                : "f"(f.y), "f"(f.x)); | ||||
|   #else | ||||
|     asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); | ||||
|     asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); | ||||
|   asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); | ||||
|   asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); | ||||
|   #endif | ||||
| #else | ||||
|   tmp.u16[0] = float_to_half(f.x); | ||||
| @ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { | ||||
| } | ||||
|  | ||||
| // Vector multiplication. | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ uint16_t mul(uint16_t a, uint16_t b) { | ||||
|   uint16_t c; | ||||
| #ifndef USE_ROCM | ||||
| @ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ uint32_t mul(uint32_t a, uint32_t b) { | ||||
|   uint32_t c; | ||||
| #ifndef USE_ROCM | ||||
| @ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ uint32_t mul(uint16_t a, uint32_t b) { | ||||
|   return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ uint2 mul(uint2 a, uint2 b) { | ||||
|   uint2 c; | ||||
|   c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); | ||||
| @ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ uint2 mul(uint16_t a, uint2 b) { | ||||
|   uint32_t s = h0_h0(a); | ||||
|   uint2 c; | ||||
| @ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ uint4 mul(uint4 a, uint4 b) { | ||||
|   uint4 c; | ||||
|   c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); | ||||
| @ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ uint4 mul(uint16_t a, uint4 b) { | ||||
|   uint32_t s = h0_h0(a); | ||||
|   uint4 c; | ||||
| @ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float mul(uint16_t a, uint16_t b) { | ||||
|   float fa = half_to_float(a); | ||||
|   float fb = half_to_float(b); | ||||
|   return fa * fb; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float2 mul(uint32_t a, uint32_t b) { | ||||
|   float2 fa = half2_to_float2(a); | ||||
|   float2 fb = half2_to_float2(b); | ||||
|   return mul<float2, float2, float2>(fa, fb); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float2 mul(uint16_t a, uint32_t b) { | ||||
|   return mul<float2, uint32_t, uint32_t>(h0_h0(a), b); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float4_ mul(uint2 a, uint2 b) { | ||||
|   Float4_ fc; | ||||
|   fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); | ||||
| @ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { | ||||
|   return fc; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float4_ mul(uint16_t a, uint2 b) { | ||||
|   uint32_t s = h0_h0(a); | ||||
|   Float4_ fc; | ||||
| @ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { | ||||
|   return fc; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float8_ mul(uint4 a, uint4 b) { | ||||
|   Float8_ fc; | ||||
|   fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); | ||||
| @ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { | ||||
|   return fc; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ Float8_ mul(uint16_t a, uint4 b) { | ||||
|   uint32_t s = h0_h0(a); | ||||
|   Float8_ fc; | ||||
| @ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { | ||||
| inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { | ||||
|   uint32_t d; | ||||
| #ifndef USE_ROCM | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                : "=r"(d) | ||||
|                : "r"(a), "r"(b), "r"(c)); | ||||
| #else | ||||
|   asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); | ||||
|   asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" | ||||
|                : "=v"(d) | ||||
|                : "v"(a), "v"(b), "v"(c)); | ||||
| #endif | ||||
|   return d; | ||||
| } | ||||
| @ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { | ||||
| } | ||||
|  | ||||
| // Vector sum. | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(uint16_t v) { | ||||
|   return half_to_float(v); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(uint32_t v) { | ||||
|   float2 tmp = half2_to_float2(v); | ||||
|   return tmp.x + tmp.y; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(uint2 v) { | ||||
|   uint32_t c = add(v.x, v.y); | ||||
|   return sum(c); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(uint4 v) { | ||||
|   uint32_t c = add(v.x, v.y); | ||||
|   c = add(c, v.z); | ||||
| @ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { | ||||
| } | ||||
|  | ||||
| // From float16 to float32. | ||||
| inline __device__ float to_float(uint16_t u) { | ||||
|   return half_to_float(u); | ||||
| } | ||||
| inline __device__ float to_float(uint16_t u) { return half_to_float(u); } | ||||
|  | ||||
| inline __device__ float2 to_float(uint32_t u) { | ||||
|   return half2_to_float2(u); | ||||
| } | ||||
| inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } | ||||
|  | ||||
| inline __device__ Float4_ to_float(uint2 u) { | ||||
|   Float4_ tmp; | ||||
| @ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { | ||||
| } | ||||
|  | ||||
| // Zero-out a variable. | ||||
| inline __device__ void zero(uint16_t& dst) { | ||||
|   dst = uint16_t(0); | ||||
| } | ||||
| inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| /* | ||||
|  * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Adapted from | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp | ||||
|  * and | ||||
|  * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h | ||||
|  * Copyright (c) 2023, The vLLM team. | ||||
|  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved. | ||||
|  * | ||||
| @ -38,37 +40,35 @@ struct Float8_ { | ||||
| }; | ||||
|  | ||||
| // FP32 vector types for Q, K, V. | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<float, 1> { | ||||
|   using Type = float; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<float, 2> { | ||||
|   using Type = float2; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<float, 4> { | ||||
|   using Type = float4; | ||||
| }; | ||||
|  | ||||
| // FP32 accumulator vector types corresponding to Vec. | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<float> { | ||||
|   using Type = float; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<float2> { | ||||
|   using Type = float2; | ||||
| }; | ||||
| template<> | ||||
| template <> | ||||
| struct FloatVec<float4> { | ||||
|   using Type = float4; | ||||
| }; | ||||
|  | ||||
| // Vector addition. | ||||
| inline __device__ float add(float a, float b) { | ||||
|   return a + b; | ||||
| } | ||||
| inline __device__ float add(float a, float b) { return a + b; } | ||||
|  | ||||
| inline __device__ float2 add(float2 a, float2 b) { | ||||
|   float2 c; | ||||
| @ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { | ||||
| } | ||||
|  | ||||
| // Vector multiplication. | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float mul<float, float>(float a, float b) { | ||||
|   return a * b; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float2 mul(float2 a, float2 b) { | ||||
|   float2 c; | ||||
|   c.x = a.x * b.x; | ||||
| @ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float2 mul(float a, float2 b) { | ||||
|   float2 c; | ||||
|   c.x = a * b.x; | ||||
| @ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float4 mul(float4 a, float4 b) { | ||||
|   float4 c; | ||||
|   c.x = a.x * b.x; | ||||
| @ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { | ||||
|   return c; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float4 mul(float a, float4 b) { | ||||
|   float4 c; | ||||
|   c.x = a * b.x; | ||||
| @ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { | ||||
| } | ||||
|  | ||||
| // Vector fused multiply-add. | ||||
| inline __device__ float fma(float a, float b, float c) { | ||||
|   return a * b + c; | ||||
| } | ||||
| inline __device__ float fma(float a, float b, float c) { return a * b + c; } | ||||
|  | ||||
| inline __device__ float2 fma(float2 a, float2 b, float2 c) { | ||||
|   float2 d; | ||||
| @ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { | ||||
| } | ||||
|  | ||||
| // Vector sum. | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(float v) { | ||||
|   return v; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(float2 v) { | ||||
|   return v.x + v.y; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(float4 v) { | ||||
|   return v.x + v.y + v.z + v.w; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(Float4_ v) { | ||||
|   return v.x.x + v.x.y + v.y.x + v.y.y; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| inline __device__ float sum(Float8_ v) { | ||||
|   return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; | ||||
| } | ||||
|  | ||||
| // Vector dot product. | ||||
| inline __device__ float dot(float a, float b) { | ||||
|   return a * b; | ||||
| } | ||||
| inline __device__ float dot(float a, float b) { return a * b; } | ||||
|  | ||||
| inline __device__ float dot(float2 a, float2 b) { | ||||
|   float2 c = mul<float2, float2, float2>(a, b); | ||||
| @ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { | ||||
| } | ||||
|  | ||||
| // From float to float. | ||||
| inline __device__ void from_float(float& dst, float src) { | ||||
|   dst = src; | ||||
| } | ||||
| inline __device__ void from_float(float& dst, float src) { dst = src; } | ||||
|  | ||||
| inline __device__ void from_float(float2& dst, float2 src) { | ||||
|   dst = src; | ||||
| } | ||||
| inline __device__ void from_float(float2& dst, float2 src) { dst = src; } | ||||
|  | ||||
| inline __device__ void from_float(float4& dst, float4 src) { | ||||
|   dst = src; | ||||
| } | ||||
| inline __device__ void from_float(float4& dst, float4 src) { dst = src; } | ||||
|  | ||||
| // From float to float. | ||||
| inline __device__ float to_float(float u) { | ||||
|   return u; | ||||
| } | ||||
| inline __device__ float to_float(float u) { return u; } | ||||
|  | ||||
| inline __device__ float2 to_float(float2 u) { | ||||
|   return u; | ||||
| } | ||||
| inline __device__ float2 to_float(float2 u) { return u; } | ||||
|  | ||||
| inline __device__ float4 to_float(float4 u) { | ||||
|   return u; | ||||
| } | ||||
| inline __device__ float4 to_float(float4 u) { return u; } | ||||
|  | ||||
| inline __device__ Float4_ to_float(Float4_ u) { | ||||
|   return u; | ||||
| } | ||||
| inline __device__ Float4_ to_float(Float4_ u) { return u; } | ||||
|  | ||||
| inline __device__ Float8_ to_float(Float8_ u) { | ||||
|   return u; | ||||
| } | ||||
| inline __device__ Float8_ to_float(Float8_ u) { return u; } | ||||
|  | ||||
| // Zero-out a variable. | ||||
| inline __device__ void zero(float& dst) { | ||||
|   dst = 0.f; | ||||
| } | ||||
| inline __device__ void zero(float& dst) { dst = 0.f; } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| @ -3,33 +3,39 @@ | ||||
| #include "attention_generic.cuh" | ||||
|  | ||||
| #include <stdint.h> | ||||
| #ifdef ENABLE_FP8_E5M2 | ||||
| #include <cuda_fp8.h> | ||||
| #endif | ||||
| #ifdef ENABLE_FP8 | ||||
|   #ifndef USE_ROCM | ||||
|     #include <cuda_fp8.h> | ||||
|   #endif  // USE_ROCM | ||||
| #endif    // ENABLE_FP8 | ||||
|  | ||||
| namespace vllm { | ||||
| #if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) | ||||
|  | ||||
| enum class Fp8KVCacheDataType { | ||||
|   kAuto = 0, | ||||
|   kFp8E4M3 = 1, | ||||
|   kFp8E5M2 = 2, | ||||
| }; | ||||
|  | ||||
| // fp8 vector types for quantization of kv cache | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint8_t, 1> { | ||||
|     using Type = uint8_t; | ||||
|   using Type = uint8_t; | ||||
| }; | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint8_t, 2> { | ||||
|     using Type = uint16_t; | ||||
|   using Type = uint16_t; | ||||
| }; | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint8_t, 4> { | ||||
|     using Type = uint32_t; | ||||
|   using Type = uint32_t; | ||||
| }; | ||||
|  | ||||
| template<> | ||||
| template <> | ||||
| struct Vec<uint8_t, 8> { | ||||
|     using Type = uint2; | ||||
|   using Type = uint2; | ||||
| }; | ||||
| #endif // ENABLE_FP8_E5M2 | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
							
								
								
									
										42
									
								
								csrc/cache.h
									
									
									
									
									
								
							
							
						
						
									
										42
									
								
								csrc/cache.h
									
									
									
									
									
								
							| @ -5,34 +5,24 @@ | ||||
| #include <map> | ||||
| #include <vector> | ||||
|  | ||||
| void swap_blocks( | ||||
|   torch::Tensor& src, | ||||
|   torch::Tensor& dst, | ||||
|   const std::map<int64_t, int64_t>& block_mapping); | ||||
| void swap_blocks(torch::Tensor& src, torch::Tensor& dst, | ||||
|                  const torch::Tensor& block_mapping); | ||||
|  | ||||
| void copy_blocks( | ||||
|   std::vector<torch::Tensor>& key_caches, | ||||
|   std::vector<torch::Tensor>& value_caches, | ||||
|   const std::map<int64_t, std::vector<int64_t>>& block_mapping); | ||||
| void copy_blocks(std::vector<torch::Tensor>& key_caches, | ||||
|                  std::vector<torch::Tensor>& value_caches, | ||||
|                  const torch::Tensor& block_mapping); | ||||
|  | ||||
| void reshape_and_cache( | ||||
|   torch::Tensor& key, | ||||
|   torch::Tensor& value, | ||||
|   torch::Tensor& key_cache, | ||||
|   torch::Tensor& value_cache, | ||||
|   torch::Tensor& slot_mapping, | ||||
|   const std::string& kv_cache_dtype, | ||||
|   const float kv_scale); | ||||
| void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, | ||||
|                        torch::Tensor& key_cache, torch::Tensor& value_cache, | ||||
|                        torch::Tensor& slot_mapping, | ||||
|                        const std::string& kv_cache_dtype, const float kv_scale); | ||||
|  | ||||
| void reshape_and_cache_flash( | ||||
|   torch::Tensor& key, | ||||
|   torch::Tensor& value, | ||||
|   torch::Tensor& key_cache, | ||||
|   torch::Tensor& value_cache, | ||||
|   torch::Tensor& slot_mapping, | ||||
|   const std::string& kv_cache_dtype); | ||||
| void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, | ||||
|                              torch::Tensor& key_cache, | ||||
|                              torch::Tensor& value_cache, | ||||
|                              torch::Tensor& slot_mapping, | ||||
|                              const std::string& kv_cache_dtype); | ||||
|  | ||||
| // Just for unittest | ||||
| void convert_fp8( | ||||
|   torch::Tensor& src_cache, | ||||
|   torch::Tensor& dst_cache); | ||||
| void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, | ||||
|                  const float scale, const std::string& kv_cache_dtype); | ||||
|  | ||||
| @ -4,10 +4,11 @@ | ||||
|  | ||||
| #include "cuda_compat.h" | ||||
| #include "dispatch_utils.h" | ||||
| #if defined(ENABLE_FP8_E5M2) | ||||
| #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" | ||||
| #elif defined(ENABLE_FP8_E4M3) | ||||
| #include "quantization/fp8/amd_detail/quant_utils.cuh" | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|   #include "quantization/fp8/amd/quant_utils.cuh" | ||||
| #else | ||||
|   #include "quantization/fp8/nvidia/quant_utils.cuh" | ||||
| #endif | ||||
|  | ||||
| #include <algorithm> | ||||
| @ -17,20 +18,17 @@ | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|   #include <hip/hip_bf16.h> | ||||
|   typedef __hip_bfloat16 __nv_bfloat16; | ||||
| typedef __hip_bfloat16 __nv_bfloat16; | ||||
| #endif | ||||
|  | ||||
| void swap_blocks( | ||||
|   torch::Tensor& src, | ||||
|   torch::Tensor& dst, | ||||
|   const std::map<int64_t, int64_t>& block_mapping) { | ||||
| void swap_blocks(torch::Tensor& src, torch::Tensor& dst, | ||||
|                  const torch::Tensor& block_mapping) { | ||||
|   torch::Device src_device = src.device(); | ||||
|   torch::Device dst_device = dst.device(); | ||||
|   cudaMemcpyKind memcpy_type; | ||||
|   if (src_device.is_cuda() && dst_device.is_cuda()) { | ||||
|     TORCH_CHECK( | ||||
|       src_device.index() == dst_device.index(), | ||||
|       "src and dst must be on the same GPU"); | ||||
|     TORCH_CHECK(src_device.index() == dst_device.index(), | ||||
|                 "src and dst must be on the same GPU"); | ||||
|     memcpy_type = cudaMemcpyDeviceToDevice; | ||||
|   } else if (src_device.is_cuda() && dst_device.is_cpu()) { | ||||
|     memcpy_type = cudaMemcpyDeviceToHost; | ||||
| @ -40,41 +38,44 @@ void swap_blocks( | ||||
|     TORCH_CHECK(false, "Invalid device combination"); | ||||
|   } | ||||
|  | ||||
|   char *src_ptr = static_cast<char*>(src.data_ptr()); | ||||
|   char *dst_ptr = static_cast<char*>(dst.data_ptr()); | ||||
|   // NOTE(youkaichao): keep in mind that `block_mapping` should be | ||||
|   // a cpu tensor, otherwise every `item` call will require a gpu-cpu | ||||
|   // synchronization. | ||||
|   TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); | ||||
|  | ||||
|   char* src_ptr = static_cast<char*>(src.data_ptr()); | ||||
|   char* dst_ptr = static_cast<char*>(dst.data_ptr()); | ||||
|  | ||||
|   const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard( | ||||
|       src_device.is_cuda() ? src_device : dst_device); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   // NOTE(woosuk): This can be slow if the number of blocks is large. | ||||
|   for (const auto& pair : block_mapping) { | ||||
|     int64_t src_block_number = pair.first; | ||||
|     int64_t dst_block_number = pair.second; | ||||
|   const int64_t num_blocks = block_mapping.size(0); | ||||
|   for (size_t i = 0; i < num_blocks; i++) { | ||||
|     int64_t src_block_number = block_mapping[i][0].item<int64_t>(); | ||||
|     int64_t dst_block_number = block_mapping[i][1].item<int64_t>(); | ||||
|     int64_t src_offset = src_block_number * block_size_in_bytes; | ||||
|     int64_t dst_offset = dst_block_number * block_size_in_bytes; | ||||
|     cudaMemcpyAsync( | ||||
|       dst_ptr + dst_offset, | ||||
|       src_ptr + src_offset, | ||||
|       block_size_in_bytes, | ||||
|       memcpy_type, | ||||
|       stream); | ||||
|     cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, | ||||
|                     block_size_in_bytes, memcpy_type, stream); | ||||
|   } | ||||
| } | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| // Grid: (num_layers, num_pairs) | ||||
| template<typename scalar_t> | ||||
| __global__ void copy_blocks_kernel( | ||||
|   int64_t* key_cache_ptrs, | ||||
|   int64_t* value_cache_ptrs, | ||||
|   const int64_t* __restrict__ block_mapping, | ||||
|   const int numel_per_block) { | ||||
| template <typename scalar_t> | ||||
| __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, | ||||
|                                    int64_t* value_cache_ptrs, | ||||
|                                    const int64_t* __restrict__ block_mapping, | ||||
|                                    const int numel_per_block) { | ||||
|   const int layer_idx = blockIdx.x; | ||||
|   const int pair_idx = blockIdx.y; | ||||
|  | ||||
|   scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); | ||||
|   scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); | ||||
|   scalar_t* value_cache = | ||||
|       reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); | ||||
|   int64_t src_block_number = block_mapping[2 * pair_idx]; | ||||
|   int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; | ||||
|  | ||||
| @ -92,12 +93,11 @@ __global__ void copy_blocks_kernel( | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| void copy_blocks( | ||||
|   std::vector<torch::Tensor>& key_caches, | ||||
|   std::vector<torch::Tensor>& value_caches, | ||||
|   const std::map<int64_t, std::vector<int64_t>>& block_mapping) { | ||||
| void copy_blocks(std::vector<torch::Tensor>& key_caches, | ||||
|                  std::vector<torch::Tensor>& value_caches, | ||||
|                  const torch::Tensor& block_mapping) { | ||||
|   int num_layers = key_caches.size(); | ||||
|   TORCH_CHECK(num_layers == value_caches.size()); | ||||
|   if (num_layers == 0) { | ||||
| @ -111,29 +111,23 @@ void copy_blocks( | ||||
|   int64_t key_cache_ptrs[num_layers]; | ||||
|   int64_t value_cache_ptrs[num_layers]; | ||||
|   for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { | ||||
|     key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); | ||||
|     value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); | ||||
|     key_cache_ptrs[layer_idx] = | ||||
|         reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); | ||||
|     value_cache_ptrs[layer_idx] = | ||||
|         reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); | ||||
|   } | ||||
|   // Create block mapping array. | ||||
|   std::vector<int64_t> block_mapping_vec; | ||||
|   for (const auto& pair : block_mapping) { | ||||
|     int64_t src_block_number = pair.first; | ||||
|     for (int64_t dst_block_number : pair.second) { | ||||
|       block_mapping_vec.push_back(src_block_number); | ||||
|       block_mapping_vec.push_back(dst_block_number); | ||||
|     } | ||||
|   } | ||||
|   int64_t* block_mapping_array = block_mapping_vec.data(); | ||||
|   int num_pairs = block_mapping_vec.size() / 2; | ||||
|  | ||||
|   // block_mapping is a 2D tensor with shape (num_pairs, 2). | ||||
|   int num_pairs = block_mapping.size(0); | ||||
|  | ||||
|   // Move the data structures to the GPU. | ||||
|   // NOTE: This synchronizes the CPU and GPU. | ||||
|   torch::Tensor key_cache_ptrs_tensor = torch::from_blob( | ||||
|     key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); | ||||
|   torch::Tensor value_cache_ptrs_tensor = torch::from_blob( | ||||
|     value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); | ||||
|   torch::Tensor block_mapping_tensor = torch::from_blob( | ||||
|     block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); | ||||
|   torch::Tensor key_cache_ptrs_tensor = | ||||
|       torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) | ||||
|           .to(cache_device); | ||||
|   torch::Tensor value_cache_ptrs_tensor = | ||||
|       torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) | ||||
|           .to(cache_device); | ||||
|  | ||||
|   // Launch the kernel. | ||||
|   const int numel_per_block = key_caches[0][0].numel(); | ||||
| @ -142,31 +136,28 @@ void copy_blocks( | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(cache_device); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( | ||||
|     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { | ||||
|       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|         key_cache_ptrs_tensor.data_ptr<int64_t>(), | ||||
|         value_cache_ptrs_tensor.data_ptr<int64_t>(), | ||||
|         block_mapping_tensor.data_ptr<int64_t>(), | ||||
|         numel_per_block); | ||||
|     })); | ||||
|       key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { | ||||
|         vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|             key_cache_ptrs_tensor.data_ptr<int64_t>(), | ||||
|             value_cache_ptrs_tensor.data_ptr<int64_t>(), | ||||
|             block_mapping.data_ptr<int64_t>(), numel_per_block); | ||||
|       })); | ||||
| } | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache> | ||||
| template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> | ||||
| __global__ void reshape_and_cache_kernel( | ||||
|   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size] | ||||
|   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size] | ||||
|   cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x] | ||||
|   cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size] | ||||
|   const int64_t* __restrict__ slot_mapping,   // [num_tokens] | ||||
|   const int key_stride, | ||||
|   const int value_stride, | ||||
|   const int num_heads, | ||||
|   const int head_size, | ||||
|   const int block_size, | ||||
|   const int x, | ||||
|   const float kv_scale) { | ||||
|     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size] | ||||
|     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size] | ||||
|     cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, | ||||
|                                          // block_size, x] | ||||
|     cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, | ||||
|                                          // block_size] | ||||
|     const int64_t* __restrict__ slot_mapping,  // [num_tokens] | ||||
|     const int key_stride, const int value_stride, const int num_heads, | ||||
|     const int head_size, const int block_size, const int x, | ||||
|     const float kv_scale) { | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|   const int64_t slot_idx = slot_mapping[token_idx]; | ||||
|   if (slot_idx < 0) { | ||||
| @ -187,47 +178,39 @@ __global__ void reshape_and_cache_kernel( | ||||
|     const int x_idx = head_offset / x; | ||||
|     const int x_offset = head_offset % x; | ||||
|  | ||||
|     const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x | ||||
|                                 + head_idx * (head_size / x) * block_size * x | ||||
|                                 + x_idx * block_size * x | ||||
|                                 + block_offset * x | ||||
|                                 + x_offset; | ||||
|     const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size | ||||
|                                   + head_idx * head_size * block_size | ||||
|                                   + head_offset * block_size | ||||
|                                   + block_offset; | ||||
|     const int64_t tgt_key_idx = | ||||
|         block_idx * num_heads * (head_size / x) * block_size * x + | ||||
|         head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + | ||||
|         block_offset * x + x_offset; | ||||
|     const int64_t tgt_value_idx = | ||||
|         block_idx * num_heads * head_size * block_size + | ||||
|         head_idx * head_size * block_size + head_offset * block_size + | ||||
|         block_offset; | ||||
|     scalar_t tgt_key = key[src_key_idx]; | ||||
|     scalar_t tgt_value = value[src_value_idx]; | ||||
|     if constexpr (is_fp8_kv_cache) { | ||||
| #if defined(ENABLE_FP8_E5M2) | ||||
|       key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key); | ||||
|       value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value); | ||||
| #elif defined(ENABLE_FP8_E4M3) | ||||
|       key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale); | ||||
|       value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale); | ||||
| #else | ||||
|       assert(false); | ||||
| #endif | ||||
|     } else { | ||||
|     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { | ||||
|       key_cache[tgt_key_idx] = tgt_key; | ||||
|       value_cache[tgt_value_idx] = tgt_value; | ||||
|     } else { | ||||
|       key_cache[tgt_key_idx] = | ||||
|           fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale); | ||||
|       value_cache[tgt_value_idx] = | ||||
|           fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| template <typename scalar_t> | ||||
| __global__ void reshape_and_cache_flash_kernel( | ||||
|   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size] | ||||
|   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size] | ||||
|   scalar_t* __restrict__ k_cache,             // [num_blocks, block_size, num_heads, head_size] | ||||
|   scalar_t* __restrict__ v_cache,             // [num_blocks, block_size, num_heads, head_size] | ||||
|   const int64_t* __restrict__ slot_mapping,   // [num_tokens] | ||||
|   const int block_stride, | ||||
|   const int key_stride, | ||||
|   const int value_stride, | ||||
|   const int num_heads, | ||||
|   const int head_size, | ||||
|   const int block_size) { | ||||
|     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size] | ||||
|     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size] | ||||
|     scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads, | ||||
|                                          // head_size] | ||||
|     scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads, | ||||
|                                          // head_size] | ||||
|     const int64_t* __restrict__ slot_mapping,  // [num_tokens] | ||||
|     const int block_stride, const int key_stride, const int value_stride, | ||||
|     const int num_heads, const int head_size, const int block_size) { | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|   const int64_t slot_idx = slot_mapping[token_idx]; | ||||
|   // NOTE: slot_idx can be -1 if the token is padded | ||||
| @ -242,40 +225,37 @@ __global__ void reshape_and_cache_flash_kernel( | ||||
|     const int64_t src_value_idx = token_idx * value_stride + i; | ||||
|     const int head_idx = i / head_size; | ||||
|     const int head_offset = i % head_size; | ||||
|     const int64_t tgt_value_idx = block_idx * block_stride | ||||
|                               + block_offset * num_heads * head_size | ||||
|                               + head_idx * head_size | ||||
|                               + head_offset; | ||||
|     const int64_t tgt_value_idx = block_idx * block_stride + | ||||
|                                   block_offset * num_heads * head_size + | ||||
|                                   head_idx * head_size + head_offset; | ||||
|     k_cache[tgt_value_idx] = key[src_key_idx]; | ||||
|     v_cache[tgt_value_idx] = value[src_value_idx]; | ||||
|   } | ||||
| } | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE)                                     \ | ||||
|   vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>(      \ | ||||
|     reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \ | ||||
|     reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \ | ||||
|     reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \ | ||||
|     reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \ | ||||
|     slot_mapping.data_ptr<int64_t>(),                                                              \ | ||||
|     key_stride,                                                                                    \ | ||||
|     value_stride,                                                                                  \ | ||||
|     num_heads,                                                                                     \ | ||||
|     head_size,                                                                                     \ | ||||
|     block_size,                                                                                    \ | ||||
|     x,                                                                                             \ | ||||
|     kv_scale); | ||||
| // KV_T is the stored data type of kv-cache. | ||||
| // CACHE_T is the data type of key and value tensors. | ||||
| // KV_DTYPE is the real data type of kv-cache. | ||||
| #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \ | ||||
|   vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \ | ||||
|       <<<grid, block, 0, stream>>>(                                   \ | ||||
|           reinterpret_cast<KV_T*>(key.data_ptr()),                    \ | ||||
|           reinterpret_cast<KV_T*>(value.data_ptr()),                  \ | ||||
|           reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \ | ||||
|           reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \ | ||||
|           slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \ | ||||
|           num_heads, head_size, block_size, x, kv_scale); | ||||
|  | ||||
| void reshape_and_cache( | ||||
|   torch::Tensor& key,           // [num_tokens, num_heads, head_size] | ||||
|   torch::Tensor& value,         // [num_tokens, num_heads, head_size] | ||||
|   torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x] | ||||
|   torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size] | ||||
|   torch::Tensor& slot_mapping,  // [num_tokens] | ||||
|   const std::string& kv_cache_dtype, | ||||
|   const float kv_scale) | ||||
| { | ||||
|     torch::Tensor& key,    // [num_tokens, num_heads, head_size] | ||||
|     torch::Tensor& value,  // [num_tokens, num_heads, head_size] | ||||
|     torch::Tensor& | ||||
|         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x] | ||||
|     torch::Tensor& | ||||
|         value_cache,  // [num_blocks, num_heads, head_size, block_size] | ||||
|     torch::Tensor& slot_mapping,  // [num_tokens] | ||||
|     const std::string& kv_cache_dtype, const float kv_scale) { | ||||
|   int num_tokens = key.size(0); | ||||
|   int num_heads = key.size(1); | ||||
|   int head_size = key.size(2); | ||||
| @ -289,35 +269,18 @@ void reshape_and_cache( | ||||
|   dim3 block(std::min(num_heads * head_size, 512)); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   if (kv_cache_dtype == "auto") { | ||||
|     if (key.dtype() == at::ScalarType::Float) { | ||||
|       CALL_RESHAPE_AND_CACHE(float, float, false); | ||||
|     } else if (key.dtype() == at::ScalarType::Half) { | ||||
|       CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); | ||||
|     } else if (key.dtype() == at::ScalarType::BFloat16) { | ||||
|       CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); | ||||
|     } | ||||
|   } else if (kv_cache_dtype == "fp8") { | ||||
|     if (key.dtype() == at::ScalarType::Float) { | ||||
|       CALL_RESHAPE_AND_CACHE(float, uint8_t, true); | ||||
|     } else if (key.dtype() == at::ScalarType::Half) { | ||||
|       CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); | ||||
|     } else if (key.dtype() == at::ScalarType::BFloat16) { | ||||
|       CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); | ||||
|     } | ||||
|   } else { | ||||
|     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); | ||||
|   } | ||||
|  | ||||
|   DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, | ||||
|                              CALL_RESHAPE_AND_CACHE) | ||||
| } | ||||
|  | ||||
| void reshape_and_cache_flash( | ||||
|   torch::Tensor& key,           // [num_tokens, num_heads, head_size] | ||||
|   torch::Tensor& value,         // [num_tokens, num_heads, head_size] | ||||
|   torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size] | ||||
|   torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size] | ||||
|   torch::Tensor& slot_mapping,  // [num_tokens] | ||||
|   const std::string& kv_cache_dtype) | ||||
| { | ||||
|     torch::Tensor& key,      // [num_tokens, num_heads, head_size] | ||||
|     torch::Tensor& value,    // [num_tokens, num_heads, head_size] | ||||
|     torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size] | ||||
|     torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size] | ||||
|     torch::Tensor& slot_mapping,  // [num_tokens] | ||||
|     const std::string& kv_cache_dtype) { | ||||
|   // FIXME: only support auto datatype, does not support fp8 | ||||
|   if (kv_cache_dtype != "auto") { | ||||
|     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); | ||||
| @ -337,63 +300,47 @@ void reshape_and_cache_flash( | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|     key.scalar_type(), | ||||
|     "reshape_and_cache_flash", | ||||
|     [&] { | ||||
|       vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|         key.data_ptr<scalar_t>(), | ||||
|         value.data_ptr<scalar_t>(), | ||||
|         k_cache.data_ptr<scalar_t>(), | ||||
|         v_cache.data_ptr<scalar_t>(), | ||||
|         slot_mapping.data_ptr<int64_t>(), | ||||
|         block_stride, | ||||
|         key_stride, | ||||
|         value_stride, | ||||
|         num_heads, | ||||
|         head_size, | ||||
|         block_size); | ||||
|     }); | ||||
|       key.scalar_type(), "reshape_and_cache_flash", [&] { | ||||
|         vllm::reshape_and_cache_flash_kernel<scalar_t> | ||||
|             <<<grid, block, 0, stream>>>( | ||||
|                 key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(), | ||||
|                 k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(), | ||||
|                 slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, | ||||
|                 value_stride, num_heads, head_size, block_size); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| template<typename Tout, typename Tin> | ||||
| __global__ void convert_fp8_kernel( | ||||
|   const Tin* __restrict__ src_cache, | ||||
|   Tout* __restrict__ dst_cache, | ||||
|   const int64_t block_stride) { | ||||
| template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> | ||||
| __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, | ||||
|                                    Tout* __restrict__ dst_cache, | ||||
|                                    const float kv_scale, | ||||
|                                    const int64_t block_stride) { | ||||
|   const int64_t block_idx = blockIdx.x; | ||||
|   for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { | ||||
|     int64_t idx = block_idx * block_stride + i; | ||||
| #if defined(ENABLE_FP8_E5M2) | ||||
|     dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]); | ||||
| #elif defined(ENABLE_FP8_E4M3) | ||||
|     dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]); | ||||
| #else | ||||
|     assert(false); | ||||
| #endif | ||||
|     dst_cache[idx] = | ||||
|         fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| #define CALL_CONVERT_FP8(Tout, Tin)                                 \ | ||||
|   vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \ | ||||
|     reinterpret_cast<Tin*>(src_cache.data_ptr()),                   \ | ||||
|     reinterpret_cast<Tout*>(dst_cache.data_ptr()),                  \ | ||||
|     block_stride); | ||||
| #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \ | ||||
|   vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \ | ||||
|       reinterpret_cast<Tin*>(src_cache.data_ptr()),                          \ | ||||
|       reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride); | ||||
|  | ||||
| void convert_fp8( | ||||
|   torch::Tensor& src_cache, | ||||
|   torch::Tensor& dst_cache) | ||||
| { | ||||
| // Only for testing. | ||||
| void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, | ||||
|                  const float kv_scale, const std::string& kv_cache_dtype) { | ||||
|   torch::Device src_device = src_cache.device(); | ||||
|   torch::Device dst_device = dst_cache.device(); | ||||
|   TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") | ||||
|   TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") | ||||
|   TORCH_CHECK( | ||||
|     src_device.index() == dst_device.index(), | ||||
|     "src and dst must be on the same GPU"); | ||||
|   TORCH_CHECK(src_device.index() == dst_device.index(), | ||||
|               "src and dst must be on the same GPU"); | ||||
|   at::cuda::OptionalCUDAGuard device_guard(src_device); | ||||
|  | ||||
|   int64_t num_blocks = src_cache.size(0); | ||||
| @ -403,17 +350,37 @@ void convert_fp8( | ||||
|   dim3 block(std::min(block_stride, int64_t(512))); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
|   if (src_cache.dtype() == at::ScalarType::Float) { | ||||
|     CALL_CONVERT_FP8(uint8_t, float); | ||||
|   } else if (src_cache.dtype() == at::ScalarType::Half) { | ||||
|     CALL_CONVERT_FP8(uint8_t, uint16_t); | ||||
|   } else if (src_cache.dtype() == at::ScalarType::BFloat16) { | ||||
|     CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); | ||||
|   } else if (dst_cache.dtype() == at::ScalarType::Float) { | ||||
|     CALL_CONVERT_FP8(float, uint8_t); | ||||
|   } else if (dst_cache.dtype() == at::ScalarType::Half) { | ||||
|     CALL_CONVERT_FP8(uint16_t, uint8_t); | ||||
|   } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { | ||||
|     CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); | ||||
|   if (kv_cache_dtype == "auto") { | ||||
|     if (src_cache.dtype() == at::ScalarType::Float) { | ||||
|       CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); | ||||
|     } else if (src_cache.dtype() == at::ScalarType::Half) { | ||||
|       CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); | ||||
|     } else if (src_cache.dtype() == at::ScalarType::BFloat16) { | ||||
|       CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); | ||||
|     } else if (dst_cache.dtype() == at::ScalarType::Float) { | ||||
|       CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); | ||||
|     } else if (dst_cache.dtype() == at::ScalarType::Half) { | ||||
|       CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); | ||||
|     } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { | ||||
|       CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); | ||||
|     } | ||||
|   } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { | ||||
|     if (src_cache.dtype() == at::ScalarType::Float) { | ||||
|       CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); | ||||
|     } else if (src_cache.dtype() == at::ScalarType::Half) { | ||||
|       CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); | ||||
|     } else if (src_cache.dtype() == at::ScalarType::BFloat16) { | ||||
|       CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, | ||||
|                        vllm::Fp8KVCacheDataType::kFp8E4M3); | ||||
|     } else if (dst_cache.dtype() == at::ScalarType::Float) { | ||||
|       CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); | ||||
|     } else if (dst_cache.dtype() == at::ScalarType::Half) { | ||||
|       CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); | ||||
|     } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { | ||||
|       CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, | ||||
|                        vllm::Fp8KVCacheDataType::kFp8E4M3); | ||||
|     } | ||||
|   } else { | ||||
|     TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); | ||||
|   } | ||||
| } | ||||
|  | ||||
| @ -1,10 +1,10 @@ | ||||
| #include "cpu_types.hpp" | ||||
|  | ||||
| namespace { | ||||
| template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), | ||||
| template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&), | ||||
|           bool is_gated> | ||||
| void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, | ||||
|                        scalar_t *__restrict__ output) { | ||||
| void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, | ||||
|                        scalar_t* __restrict__ output) { | ||||
|   using scalar_vec_t = vec_op::vec_t<scalar_t>; | ||||
|   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); | ||||
|  | ||||
| @ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, | ||||
|   } | ||||
| } | ||||
|  | ||||
| FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { | ||||
| FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { | ||||
|   const vec_op::FP32Vec8 zeros(0.0); | ||||
|   const vec_op::FP32Vec8 ones(1.0); | ||||
|   return x / (ones + (zeros - x).exp()); | ||||
| } | ||||
|  | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { | ||||
|   const vec_op::FP32Vec8 ones(1.0); | ||||
|   const vec_op::FP32Vec8 w1(0.79788456f); | ||||
|   const vec_op::FP32Vec8 w2(0.044715f); | ||||
| @ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { | ||||
|   return w3 * x * (ones + t); | ||||
| } | ||||
|  | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { | ||||
|   const vec_op::FP32Vec8 ones(1.0); | ||||
|   const vec_op::FP32Vec8 w1(0.79788456f); | ||||
|   const vec_op::FP32Vec8 w2(0.044715f); | ||||
| @ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { | ||||
|   return w3 * x * (ones + t); | ||||
| } | ||||
|  | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { | ||||
|   const vec_op::FP32Vec8 ones(1.0); | ||||
|   const vec_op::FP32Vec8 w1(M_SQRT1_2); | ||||
|   const vec_op::FP32Vec8 w2(0.5); | ||||
|   return x * w2 * (ones + (x * w1).er()); | ||||
| } | ||||
|  | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { | ||||
| FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { | ||||
|   const vec_op::FP32Vec8 ones(1.0); | ||||
|   const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); | ||||
|   const vec_op::FP32Vec8 w2(0.5); | ||||
| @ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { | ||||
|   const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); | ||||
|   return x * w2 * (ones + inner.tanh()); | ||||
| } | ||||
| }; // namespace | ||||
| };  // namespace | ||||
|  | ||||
| void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { | ||||
| void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { | ||||
|   int num_tokens = input.numel() / input.size(-1); | ||||
|   int d = input.size(-1) / 2; | ||||
|  | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "silu_and_mul_impl", [&] { | ||||
|         CPU_KERNEL_GUARD_IN(silu_and_mul_impl) | ||||
|         activation_kernel<scalar_t, silu_act, true>(num_tokens, d, | ||||
|                                                     input.data_ptr<scalar_t>(), | ||||
|                                                     out.data_ptr<scalar_t>()); | ||||
|         CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) | ||||
|       }); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { | ||||
|     CPU_KERNEL_GUARD_IN(silu_and_mul_impl) | ||||
|     activation_kernel<scalar_t, silu_act, true>( | ||||
|         num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>()); | ||||
|     CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void gelu_and_mul(torch::Tensor &out,   // [..., d] | ||||
|                       torch::Tensor &input) // [..., 2 * d] | ||||
| void gelu_and_mul(torch::Tensor& out,    // [..., d] | ||||
|                   torch::Tensor& input)  // [..., 2 * d] | ||||
| { | ||||
|   int num_tokens = input.numel() / input.size(-1); | ||||
|   int d = input.size(-1) / 2; | ||||
|  | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "gelu_and_mul_impl", [&] { | ||||
|         CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) | ||||
|         activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, | ||||
|                                                     input.data_ptr<scalar_t>(), | ||||
|                                                     out.data_ptr<scalar_t>()); | ||||
|         CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) | ||||
|       }); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { | ||||
|     CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) | ||||
|     activation_kernel<scalar_t, gelu_act, true>( | ||||
|         num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>()); | ||||
|     CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void gelu_tanh_and_mul(torch::Tensor &out,   // [..., d] | ||||
|                            torch::Tensor &input) // [..., 2 * d] | ||||
| void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d] | ||||
|                        torch::Tensor& input)  // [..., 2 * d] | ||||
| { | ||||
|   int num_tokens = input.numel() / input.size(-1); | ||||
|   int d = input.size(-1) / 2; | ||||
| @ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out,   // [..., d] | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void gelu_new(torch::Tensor &out, torch::Tensor &input) { | ||||
| void gelu_new(torch::Tensor& out, torch::Tensor& input) { | ||||
|   int num_tokens = input.numel() / input.size(-1); | ||||
|   int d = input.size(-1); | ||||
|  | ||||
| @ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void gelu_fast(torch::Tensor &out, torch::Tensor &input) { | ||||
| void gelu_fast(torch::Tensor& out, torch::Tensor& input) { | ||||
|   int num_tokens = input.numel() / input.size(-1); | ||||
|   int d = input.size(-1); | ||||
|  | ||||
|  | ||||
| @ -2,7 +2,8 @@ | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename scalar_t> struct KernelVecType { | ||||
| template <typename scalar_t> | ||||
| struct KernelVecType { | ||||
|   using q_load_vec_type = void; | ||||
|   using q_vec_type = void; | ||||
|   using k_load_vec_type = void; | ||||
| @ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType { | ||||
|   using v_load_vec_type = void; | ||||
| }; | ||||
|  | ||||
| template <> struct KernelVecType<float> { | ||||
| template <> | ||||
| struct KernelVecType<float> { | ||||
|   using q_load_vec_type = vec_op::FP32Vec4; | ||||
|   using q_vec_type = vec_op::FP32Vec16; | ||||
|   using k_load_vec_type = vec_op::FP32Vec16; | ||||
| @ -21,7 +23,8 @@ template <> struct KernelVecType<float> { | ||||
| }; | ||||
|  | ||||
| #ifdef __AVX512BF16__ | ||||
| template <> struct KernelVecType<c10::BFloat16> { | ||||
| template <> | ||||
| struct KernelVecType<c10::BFloat16> { | ||||
|   using q_load_vec_type = vec_op::BF16Vec8; | ||||
|   using q_vec_type = vec_op::BF16Vec32; | ||||
|   using k_load_vec_type = vec_op::BF16Vec32; | ||||
| @ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> { | ||||
|   using v_load_vec_type = vec_op::BF16Vec16; | ||||
| }; | ||||
| #else | ||||
| template <> struct KernelVecType<c10::BFloat16> { | ||||
| template <> | ||||
| struct KernelVecType<c10::BFloat16> { | ||||
|   using q_load_vec_type = vec_op::BF16Vec8; | ||||
|   using q_vec_type = vec_op::FP32Vec16; | ||||
|   using k_load_vec_type = vec_op::BF16Vec16; | ||||
| @ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> { | ||||
| #endif | ||||
|  | ||||
| template <typename T> | ||||
| FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, | ||||
| FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size, | ||||
|                                            const int capacity) { | ||||
|   T max = data[0]; | ||||
|   for (int i = 1; i < size; ++i) { | ||||
| @ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| FORCE_INLINE std::pair<T, T> | ||||
| reduceSoftmaxAlibi(T *data, const int size, const int capacity, | ||||
|                    const float alibi_slope, const int start_index, | ||||
|                    const int seq_len) { | ||||
| FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size, | ||||
|                                                 const int capacity, | ||||
|                                                 const float alibi_slope, | ||||
|                                                 const int start_index, | ||||
|                                                 const int seq_len) { | ||||
|   data[0] += alibi_slope * (start_index - seq_len + 1); | ||||
|   T max = data[0]; | ||||
|   for (int i = 1; i < size; ++i) { | ||||
| @ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, | ||||
| FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, | ||||
|                                         const int size) { | ||||
|   T max = max_data[0]; | ||||
|   for (int i = 1; i < size; ++i) { | ||||
| @ -132,9 +137,9 @@ struct reduceQKBlockKernel { | ||||
|   static_assert(k_load_vec_type::get_elem_num() % x == 0); | ||||
|   static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); | ||||
|  | ||||
|   FORCE_INLINE static void call(const scalar_t *__restrict__ q, | ||||
|                                 const scalar_t *__restrict__ k_block, | ||||
|                                 float *__restrict__ logits, float scale, | ||||
|   FORCE_INLINE static void call(const scalar_t* __restrict__ q, | ||||
|                                 const scalar_t* __restrict__ k_block, | ||||
|                                 float* __restrict__ logits, float scale, | ||||
|                                 const int token_num) { | ||||
|     const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; | ||||
|  | ||||
| @ -196,8 +201,8 @@ struct reduceQKBlockKernel { | ||||
|  | ||||
| template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, | ||||
|           int HEAD_PARTITION_SIZE, typename acc_t> | ||||
| FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, | ||||
|                                    acc_t &&acc) { | ||||
| FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, | ||||
|                                    acc_t&& acc) { | ||||
|   using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; | ||||
|   constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); | ||||
|   static_assert(BLOCK_SIZE == ELEM_NUM); | ||||
| @ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, | ||||
|     acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; | ||||
|   }); | ||||
| } | ||||
| }; // namespace | ||||
| };  // namespace | ||||
|  | ||||
| // Paged attention v1 | ||||
| namespace { | ||||
| template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE> | ||||
| struct paged_attention_v1_impl { | ||||
|   static void | ||||
|   call(scalar_t *__restrict__ out,           // [num_seqs, num_heads, head_size] | ||||
|        const scalar_t *__restrict__ q,       // [num_seqs, num_heads, head_size] | ||||
|        const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, | ||||
|   static void call( | ||||
|       scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size] | ||||
|       const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size] | ||||
|       const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, | ||||
|                                              // head_size/x, block_size, x] | ||||
|        const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, | ||||
|       const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, | ||||
|                                              // head_size, block_size] | ||||
|        const int num_kv_heads, const float scale, | ||||
|        const int | ||||
|            *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] | ||||
|        const int *__restrict__ seq_lens, // [num_seqs] | ||||
|        const int max_num_blocks_per_seq, | ||||
|        const float *__restrict__ alibi_slopes, // [num_heads] | ||||
|        const int q_stride, const int kv_block_stride, const int kv_head_stride, | ||||
|        const int num_seqs, const int num_heads) { | ||||
|       const int num_kv_heads, const float scale, | ||||
|       const int* __restrict__ block_tables,  // [num_seqs, | ||||
|                                              // max_num_blocks_per_seq] | ||||
|       const int* __restrict__ seq_lens,      // [num_seqs] | ||||
|       const int max_num_blocks_per_seq, | ||||
|       const float* __restrict__ alibi_slopes,  // [num_heads] | ||||
|       const int q_stride, const int kv_block_stride, const int kv_head_stride, | ||||
|       const int num_seqs, const int num_heads) { | ||||
|     constexpr int x = 16 / sizeof(scalar_t); | ||||
|     const int num_queries_per_kv = num_heads / num_kv_heads; | ||||
|  | ||||
| @ -243,32 +248,31 @@ struct paged_attention_v1_impl { | ||||
|  | ||||
|     size_t logits_bytes = | ||||
|         parallel_work_item_num * max_seq_len_padded * sizeof(float); | ||||
|     float *logits = (float *)std::aligned_alloc( | ||||
|         64, logits_bytes); // Cacheline alignment for each context token. | ||||
|                            // [parallel_work_item_num, max_seq_len_padded] | ||||
|     float* logits = (float*)std::aligned_alloc( | ||||
|         64, logits_bytes);  // Cacheline alignment for each context token. | ||||
|                             // [parallel_work_item_num, max_seq_len_padded] | ||||
|  | ||||
| #pragma omp parallel for collapse(2) schedule(dynamic, 1) | ||||
|     for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { | ||||
|       for (int head_idx = 0; head_idx < num_heads; ++head_idx) { | ||||
|         int seq_len = seq_lens[seq_idx]; | ||||
|         const int *seq_block_table = | ||||
|         const int* seq_block_table = | ||||
|             block_tables + max_num_blocks_per_seq * seq_idx; | ||||
|         const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; | ||||
|         const int64_t kv_head_idx = head_idx / num_queries_per_kv; | ||||
|         const scalar_t *__restrict__ q_vec_ptr = | ||||
|         const scalar_t* __restrict__ q_vec_ptr = | ||||
|             q + seq_idx * q_stride + head_idx * HEAD_SIZE; | ||||
|         const int last_block_token_num = | ||||
|             seq_len - (block_num - 1) * BLOCK_SIZE; | ||||
|         float *__restrict__ thread_block_logits = | ||||
|         const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; | ||||
|         float* __restrict__ thread_block_logits = | ||||
|             logits + omp_get_thread_num() * max_seq_len_padded; | ||||
|  | ||||
|         // Compute logits | ||||
|         for (int block_idx = 0; block_idx < block_num; ++block_idx) { | ||||
|           const int64_t physical_block_idx = seq_block_table[block_idx]; | ||||
|           const scalar_t *__restrict__ k_block_cache_ptr = | ||||
|           const scalar_t* __restrict__ k_block_cache_ptr = | ||||
|               k_cache + physical_block_idx * kv_block_stride + | ||||
|               kv_head_idx * kv_head_stride; | ||||
|           float *__restrict__ head_block_logits = | ||||
|           float* __restrict__ head_block_logits = | ||||
|               thread_block_logits + block_idx * BLOCK_SIZE; | ||||
|  | ||||
|           reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( | ||||
| @ -282,8 +286,7 @@ struct paged_attention_v1_impl { | ||||
|                              block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, | ||||
|                              seq_len); | ||||
|         } else { | ||||
|           reduceSoftmax(thread_block_logits, seq_len, | ||||
|                         block_num * BLOCK_SIZE); | ||||
|           reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); | ||||
|         } | ||||
|  | ||||
|         // Compute value | ||||
| @ -293,14 +296,14 @@ struct paged_attention_v1_impl { | ||||
|         for (int head_part_idx = 0; head_part_idx < head_partition_num; | ||||
|              ++head_part_idx) { | ||||
|           vec_op::FP32Vec16 accums[head_elem_num_per_partition]; | ||||
|           scalar_t *__restrict__ out_ptr = | ||||
|           scalar_t* __restrict__ out_ptr = | ||||
|               out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + | ||||
|               head_part_idx * head_elem_num_per_partition; | ||||
|           for (int block_idx = 0; block_idx < block_num; ++block_idx) { | ||||
|             const int64_t physical_block_idx = seq_block_table[block_idx]; | ||||
|             const float *__restrict__ prob_vec_ptr = | ||||
|             const float* __restrict__ prob_vec_ptr = | ||||
|                 thread_block_logits + block_idx * BLOCK_SIZE; | ||||
|             const scalar_t *__restrict__ v_block_cache_ptr = | ||||
|             const scalar_t* __restrict__ v_block_cache_ptr = | ||||
|                 v_cache + physical_block_idx * kv_block_stride + | ||||
|                 kv_head_idx * kv_head_stride + | ||||
|                 BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; | ||||
| @ -311,7 +314,7 @@ struct paged_attention_v1_impl { | ||||
|             if (block_idx != block_num - 1) { | ||||
|               const int64_t next_physical_block_idx = | ||||
|                   seq_block_table[block_idx + 1]; | ||||
|               const scalar_t *__restrict__ next_v_block_cache_ptr = | ||||
|               const scalar_t* __restrict__ next_v_block_cache_ptr = | ||||
|                   v_cache + next_physical_block_idx * kv_block_stride + | ||||
|                   kv_head_idx * kv_head_stride + | ||||
|                   BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; | ||||
| @ -340,16 +343,16 @@ struct paged_attention_v1_impl { | ||||
| #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                   \ | ||||
|   paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call(                     \ | ||||
|       out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ | ||||
|       block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,              \ | ||||
|       block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,                  \ | ||||
|       alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs,   \ | ||||
|       num_heads); | ||||
|  | ||||
| template <typename T, int BLOCK_SIZE> | ||||
| void paged_attention_v1_impl_launcher( | ||||
|     torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, | ||||
|     torch::Tensor &value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor &block_tables, torch::Tensor &seq_lens, | ||||
|     int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { | ||||
|     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, | ||||
|     torch::Tensor& value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, | ||||
|     const c10::optional<torch::Tensor>& alibi_slopes) { | ||||
|   int num_seqs = query.size(0); | ||||
|   int num_heads = query.size(1); | ||||
|   int head_size = query.size(2); | ||||
| @ -359,68 +362,73 @@ void paged_attention_v1_impl_launcher( | ||||
|   int kv_head_stride = key_cache.stride(1); | ||||
|  | ||||
|   // NOTE: alibi_slopes is optional. | ||||
|   const float *alibi_slopes_ptr = | ||||
|   const float* alibi_slopes_ptr = | ||||
|       alibi_slopes | ||||
|           ? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) | ||||
|           ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) | ||||
|           : nullptr; | ||||
|  | ||||
|   T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); | ||||
|   T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); | ||||
|   T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); | ||||
|   T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); | ||||
|   int *block_tables_ptr = block_tables.data_ptr<int>(); | ||||
|   int *seq_lens_ptr = seq_lens.data_ptr<int>(); | ||||
|   T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); | ||||
|   T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); | ||||
|   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); | ||||
|   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); | ||||
|   int* block_tables_ptr = block_tables.data_ptr<int>(); | ||||
|   int* seq_lens_ptr = seq_lens.data_ptr<int>(); | ||||
|  | ||||
|   switch (head_size) { | ||||
|   case 64: | ||||
|     LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 80: | ||||
|     LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 96: | ||||
|     LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 112: | ||||
|     LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 128: | ||||
|     LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 256: | ||||
|     LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); | ||||
|     break; | ||||
|   default: | ||||
|     TORCH_CHECK(false, "Unsupported head size: ", head_size); | ||||
|     break; | ||||
|     case 64: | ||||
|       LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 80: | ||||
|       LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 96: | ||||
|       LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 112: | ||||
|       LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 128: | ||||
|       LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 192: | ||||
|       LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 256: | ||||
|       LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); | ||||
|       break; | ||||
|     default: | ||||
|       TORCH_CHECK(false, "Unsupported head size: ", head_size); | ||||
|       break; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE)                                 \ | ||||
|   paged_attention_v1_impl_launcher<T, BLOCK_SIZE>(                             \ | ||||
|       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,   \ | ||||
| #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE)                               \ | ||||
|   paged_attention_v1_impl_launcher<T, BLOCK_SIZE>(                           \ | ||||
|       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ | ||||
|       seq_lens, max_seq_len, alibi_slopes); | ||||
|  | ||||
| #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T)                                  \ | ||||
|   switch (block_size) {                                                        \ | ||||
|   case 16:                                                                     \ | ||||
|     CALL_V1_KERNEL_LAUNCHER(T, 16);                                            \ | ||||
|     break;                                                                     \ | ||||
|   default:                                                                     \ | ||||
|     TORCH_CHECK(false, "Unsupported block size: ", block_size);                \ | ||||
|     break;                                                                     \ | ||||
| #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T)                     \ | ||||
|   switch (block_size) {                                           \ | ||||
|     case 16:                                                      \ | ||||
|       CALL_V1_KERNEL_LAUNCHER(T, 16);                             \ | ||||
|       break;                                                      \ | ||||
|     default:                                                      \ | ||||
|       TORCH_CHECK(false, "Unsupported block size: ", block_size); \ | ||||
|       break;                                                      \ | ||||
|   } | ||||
| } // namespace | ||||
| }  // namespace | ||||
|  | ||||
| void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, | ||||
|                         torch::Tensor &key_cache, torch::Tensor &value_cache, | ||||
|                         int num_kv_heads, float scale, | ||||
|                         torch::Tensor &block_tables, | ||||
|                         torch::Tensor &seq_lens, int block_size, | ||||
|                         int max_seq_len, | ||||
|                         const c10::optional<torch::Tensor> &alibi_slopes, | ||||
|                         const std::string &kv_cache_dtype, float kv_scale) { | ||||
| void paged_attention_v1( | ||||
|     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, | ||||
|     torch::Tensor& value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, | ||||
|     int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, | ||||
|     const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, | ||||
|     const int blocksparse_local_blocks, const int blocksparse_vert_stride, | ||||
|     const int blocksparse_block_size, const int blocksparse_head_sliding_step) { | ||||
|   TORCH_CHECK(kv_scale == 1.0f); | ||||
|   TORCH_CHECK(blocksparse_vert_stride <= 1, | ||||
|               "CPU backend does not support blocksparse attention yet."); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", | ||||
|                                [&] { | ||||
|                                  CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) | ||||
| @ -434,23 +442,24 @@ namespace { | ||||
| template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE> | ||||
| struct paged_attention_v2_impl { | ||||
|   static void call( | ||||
|       scalar_t *__restrict__ out,   // [num_seqs, num_heads, head_size] | ||||
|       float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] | ||||
|       float | ||||
|           *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] | ||||
|       scalar_t *__restrict__ tmp_out,       // [num_seqs, num_heads, | ||||
|                                             // max_num_partitions, head_size] | ||||
|       const scalar_t *__restrict__ q,       // [num_seqs, num_heads, head_size] | ||||
|       const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, | ||||
|                                             // head_size/x, block_size, x] | ||||
|       const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, | ||||
|                                             // head_size, block_size] | ||||
|       scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size] | ||||
|       float* __restrict__ exp_sums,          // [num_seqs, num_heads, | ||||
|                                              // max_num_partitions] | ||||
|       float* __restrict__ max_logits,        // [num_seqs, num_heads, | ||||
|                                              // max_num_partitions] | ||||
|       scalar_t* __restrict__ tmp_out,        // [num_seqs, num_heads, | ||||
|                                              // max_num_partitions, head_size] | ||||
|       const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size] | ||||
|       const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, | ||||
|                                              // head_size/x, block_size, x] | ||||
|       const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, | ||||
|                                              // head_size, block_size] | ||||
|       const int num_kv_heads, const float scale, | ||||
|       const int | ||||
|           *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] | ||||
|       const int *__restrict__ seq_lens, // [num_seqs] | ||||
|       const int* __restrict__ block_tables,  // [num_seqs, | ||||
|                                              // max_num_blocks_per_seq] | ||||
|       const int* __restrict__ seq_lens,      // [num_seqs] | ||||
|       const int max_num_blocks_per_seq, | ||||
|       const float *__restrict__ alibi_slopes, // [num_heads] | ||||
|       const float* __restrict__ alibi_slopes,  // [num_heads] | ||||
|       const int q_stride, const int kv_block_stride, const int kv_head_stride, | ||||
|       const int num_seqs, const int num_heads, const int max_num_partitions) { | ||||
|     constexpr int x = 16 / sizeof(scalar_t); | ||||
| @ -468,8 +477,7 @@ struct paged_attention_v2_impl { | ||||
|           const int seq_len = seq_lens[seq_idx]; | ||||
|           const int start_token_idx = partition_idx * PARTITION_SIZE; | ||||
|  | ||||
|           if (start_token_idx >= seq_len) | ||||
|             continue; | ||||
|           if (start_token_idx >= seq_len) continue; | ||||
|  | ||||
|           const int partition_num = | ||||
|               (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; | ||||
| @ -477,15 +485,14 @@ struct paged_attention_v2_impl { | ||||
|           const int token_num = | ||||
|               (std::min(seq_len, start_token_idx + PARTITION_SIZE) - | ||||
|                start_token_idx); | ||||
|           const int block_num = | ||||
|               (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; | ||||
|           const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; | ||||
|           const int last_block_token_num = | ||||
|               token_num - (block_num - 1) * BLOCK_SIZE; | ||||
|           const int *seq_block_table = block_tables + | ||||
|           const int* seq_block_table = block_tables + | ||||
|                                        max_num_blocks_per_seq * seq_idx + | ||||
|                                        start_token_idx / BLOCK_SIZE; | ||||
|           const int64_t kv_head_idx = head_idx / num_queries_per_kv; | ||||
|           const scalar_t *__restrict__ q_vec_ptr = | ||||
|           const scalar_t* __restrict__ q_vec_ptr = | ||||
|               q + seq_idx * q_stride + head_idx * HEAD_SIZE; | ||||
|  | ||||
|           float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; | ||||
| @ -493,10 +500,10 @@ struct paged_attention_v2_impl { | ||||
|           // Compute logits | ||||
|           for (int block_idx = 0; block_idx < block_num; ++block_idx) { | ||||
|             const int64_t physical_block_idx = seq_block_table[block_idx]; | ||||
|             const scalar_t *__restrict__ k_block_cache_ptr = | ||||
|             const scalar_t* __restrict__ k_block_cache_ptr = | ||||
|                 k_cache + physical_block_idx * kv_block_stride + | ||||
|                 kv_head_idx * kv_head_stride; | ||||
|             float *__restrict__ head_block_logits = | ||||
|             float* __restrict__ head_block_logits = | ||||
|                 logits + block_idx * BLOCK_SIZE; | ||||
|  | ||||
|             reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( | ||||
| @ -510,13 +517,13 @@ struct paged_attention_v2_impl { | ||||
|                 logits, token_num, block_num * BLOCK_SIZE, | ||||
|                 alibi_slopes[head_idx], start_token_idx, seq_len); | ||||
|           } else { | ||||
|             max_and_sum = reduceSoftmax(logits, token_num, | ||||
|                                         block_num * BLOCK_SIZE); | ||||
|             max_and_sum = | ||||
|                 reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); | ||||
|           } | ||||
|  | ||||
|           auto &&[max_logit, exp_sum] = max_and_sum; | ||||
|           auto&& [max_logit, exp_sum] = max_and_sum; | ||||
|  | ||||
|           scalar_t *__restrict__ output_buffer = nullptr; | ||||
|           scalar_t* __restrict__ output_buffer = nullptr; | ||||
|           if (!no_reduce) { | ||||
|             auto idx = seq_idx * num_heads * max_num_partitions + | ||||
|                        head_idx * max_num_partitions + partition_idx; | ||||
| @ -538,13 +545,13 @@ struct paged_attention_v2_impl { | ||||
|           for (int head_part_idx = 0; head_part_idx < head_partition_num; | ||||
|                ++head_part_idx) { | ||||
|             vec_op::FP32Vec16 accums[head_elem_num_per_partition]; | ||||
|             scalar_t *__restrict__ out_ptr = | ||||
|             scalar_t* __restrict__ out_ptr = | ||||
|                 output_buffer + head_part_idx * head_elem_num_per_partition; | ||||
|             for (int block_idx = 0; block_idx < block_num; ++block_idx) { | ||||
|               const int64_t physical_block_idx = seq_block_table[block_idx]; | ||||
|               const float *__restrict__ prob_vec_ptr = | ||||
|               const float* __restrict__ prob_vec_ptr = | ||||
|                   logits + block_idx * BLOCK_SIZE; | ||||
|               const scalar_t *__restrict__ v_block_cache_ptr = | ||||
|               const scalar_t* __restrict__ v_block_cache_ptr = | ||||
|                   v_cache + physical_block_idx * kv_block_stride + | ||||
|                   kv_head_idx * kv_head_stride + | ||||
|                   BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; | ||||
| @ -555,7 +562,7 @@ struct paged_attention_v2_impl { | ||||
|               if (block_idx != block_num - 1) { | ||||
|                 const int64_t next_physical_block_idx = | ||||
|                     seq_block_table[block_idx + 1]; | ||||
|                 const scalar_t *__restrict__ next_v_block_cache_ptr = | ||||
|                 const scalar_t* __restrict__ next_v_block_cache_ptr = | ||||
|                     v_cache + next_physical_block_idx * kv_block_stride + | ||||
|                     kv_head_idx * kv_head_stride + | ||||
|                     BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; | ||||
| @ -587,8 +594,7 @@ struct paged_attention_v2_impl { | ||||
|         const int partition_num = | ||||
|             (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; | ||||
|  | ||||
|         if (partition_num == 1) | ||||
|           continue; | ||||
|         if (partition_num == 1) continue; | ||||
|  | ||||
|         reducePartitonSoftmax( | ||||
|             max_logits + seq_idx * num_heads * max_num_partitions + | ||||
| @ -603,11 +609,11 @@ struct paged_attention_v2_impl { | ||||
|     using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; | ||||
|     static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); | ||||
|     constexpr int head_elem_num_per_group = | ||||
|         16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE | ||||
|             // didn't align with 64 bytes | ||||
|         16;  // Note: didn't align with the cacheline size, due to some | ||||
|              // HEAD_SIZE didn't align with 64 bytes | ||||
|     static_assert(HEAD_SIZE % head_elem_num_per_group == 0); | ||||
|     constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; | ||||
|     const float *__restrict__ rescale_factors = exp_sums; | ||||
|     const float* __restrict__ rescale_factors = exp_sums; | ||||
| #pragma omp parallel for collapse(3) schedule(static, 1) | ||||
|     for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { | ||||
|       for (int head_idx = 0; head_idx < num_heads; ++head_idx) { | ||||
| @ -616,17 +622,16 @@ struct paged_attention_v2_impl { | ||||
|           const int partition_num = | ||||
|               (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; | ||||
|  | ||||
|           if (partition_num == 1) | ||||
|             continue; | ||||
|           if (partition_num == 1) continue; | ||||
|  | ||||
|           const float *__restrict__ seq_head_rescale_factors = | ||||
|           const float* __restrict__ seq_head_rescale_factors = | ||||
|               rescale_factors + seq_idx * num_heads * max_num_partitions + | ||||
|               head_idx * max_num_partitions; | ||||
|           const scalar_t *__restrict__ seq_head_tmp_out = | ||||
|           const scalar_t* __restrict__ seq_head_tmp_out = | ||||
|               tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + | ||||
|               head_idx * max_num_partitions * HEAD_SIZE + | ||||
|               group_idx * head_elem_num_per_group; | ||||
|           scalar_t *__restrict__ seq_head_output = | ||||
|           scalar_t* __restrict__ seq_head_output = | ||||
|               out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + | ||||
|               group_idx * head_elem_num_per_group; | ||||
|  | ||||
| @ -645,21 +650,21 @@ struct paged_attention_v2_impl { | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                   \ | ||||
|   paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call(     \ | ||||
|       out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,           \ | ||||
|       key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr,   \ | ||||
|       seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \ | ||||
|       kv_block_stride, kv_head_stride, num_seqs, num_heads,                    \ | ||||
| #define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                 \ | ||||
|   paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call(   \ | ||||
|       out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,         \ | ||||
|       key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ | ||||
|       seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,      \ | ||||
|       kv_block_stride, kv_head_stride, num_seqs, num_heads,                  \ | ||||
|       max_num_partitions); | ||||
|  | ||||
| template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512> | ||||
| void paged_attention_v2_impl_launcher( | ||||
|     torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, | ||||
|     torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, | ||||
|     torch::Tensor &value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, | ||||
|     int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { | ||||
|     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, | ||||
|     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, | ||||
|     torch::Tensor& value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, | ||||
|     int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) { | ||||
|   int num_seqs = query.size(0); | ||||
|   int num_heads = query.size(1); | ||||
|   int head_size = query.size(2); | ||||
| @ -670,73 +675,78 @@ void paged_attention_v2_impl_launcher( | ||||
|   int max_num_partitions = exp_sums.size(-1); | ||||
|  | ||||
|   // NOTE: alibi_slopes is optional. | ||||
|   const float *alibi_slopes_ptr = | ||||
|   const float* alibi_slopes_ptr = | ||||
|       alibi_slopes | ||||
|           ? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) | ||||
|           ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) | ||||
|           : nullptr; | ||||
|  | ||||
|   T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); | ||||
|   float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr()); | ||||
|   float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr()); | ||||
|   T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr()); | ||||
|   T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); | ||||
|   T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); | ||||
|   T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); | ||||
|   int *block_tables_ptr = block_tables.data_ptr<int>(); | ||||
|   int *seq_lens_ptr = seq_lens.data_ptr<int>(); | ||||
|   T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); | ||||
|   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); | ||||
|   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); | ||||
|   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr()); | ||||
|   T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); | ||||
|   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); | ||||
|   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); | ||||
|   int* block_tables_ptr = block_tables.data_ptr<int>(); | ||||
|   int* seq_lens_ptr = seq_lens.data_ptr<int>(); | ||||
|  | ||||
|   switch (head_size) { | ||||
|   case 64: | ||||
|     LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 80: | ||||
|     LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 96: | ||||
|     LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 112: | ||||
|     LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 128: | ||||
|     LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); | ||||
|     break; | ||||
|   case 256: | ||||
|     LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); | ||||
|     break; | ||||
|   default: | ||||
|     TORCH_CHECK(false, "Unsupported head size: ", head_size); | ||||
|     break; | ||||
|     case 64: | ||||
|       LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 80: | ||||
|       LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 96: | ||||
|       LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 112: | ||||
|       LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 128: | ||||
|       LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 192: | ||||
|       LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); | ||||
|       break; | ||||
|     case 256: | ||||
|       LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); | ||||
|       break; | ||||
|     default: | ||||
|       TORCH_CHECK(false, "Unsupported head size: ", head_size); | ||||
|       break; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                                 \ | ||||
|   paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                             \ | ||||
|       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \ | ||||
|       num_kv_heads, scale, block_tables, seq_lens, block_size,             \ | ||||
|       max_seq_len, alibi_slopes); | ||||
| #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                              \ | ||||
|   paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                          \ | ||||
|       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \ | ||||
|       num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ | ||||
|       alibi_slopes); | ||||
|  | ||||
| #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T)                                  \ | ||||
|   switch (block_size) {                                                        \ | ||||
|   case 16:                                                                     \ | ||||
|     CALL_V2_KERNEL_LAUNCHER(T, 16);                                            \ | ||||
|     break;                                                                     \ | ||||
|   default:                                                                     \ | ||||
|     TORCH_CHECK(false, "Unsupported block size: ", block_size);                \ | ||||
|     break;                                                                     \ | ||||
| #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T)                     \ | ||||
|   switch (block_size) {                                           \ | ||||
|     case 16:                                                      \ | ||||
|       CALL_V2_KERNEL_LAUNCHER(T, 16);                             \ | ||||
|       break;                                                      \ | ||||
|     default:                                                      \ | ||||
|       TORCH_CHECK(false, "Unsupported block size: ", block_size); \ | ||||
|       break;                                                      \ | ||||
|   } | ||||
| } // namespace | ||||
| }  // namespace | ||||
|  | ||||
| void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, | ||||
|                         torch::Tensor &max_logits, torch::Tensor &tmp_out, | ||||
|                         torch::Tensor &query, torch::Tensor &key_cache, | ||||
|                         torch::Tensor &value_cache, int num_kv_heads, | ||||
|                         float scale, torch::Tensor &block_tables, | ||||
|                         torch::Tensor &seq_lens, int block_size, | ||||
|                         int max_seq_len, | ||||
|                         const c10::optional<torch::Tensor> &alibi_slopes, | ||||
|                         const std::string &kv_cache_dtype, float kv_scale) { | ||||
| void paged_attention_v2( | ||||
|     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, | ||||
|     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, | ||||
|     torch::Tensor& value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, | ||||
|     int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, | ||||
|     const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, | ||||
|     const int blocksparse_local_blocks, const int blocksparse_vert_stride, | ||||
|     const int blocksparse_block_size, const int blocksparse_head_sliding_step) { | ||||
|   TORCH_CHECK(kv_scale == 1.0f); | ||||
|   TORCH_CHECK(blocksparse_vert_stride <= 1, | ||||
|               "CPU backend does not support blocksparse attention yet."); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", | ||||
|                                [&] { | ||||
|                                  CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) | ||||
|  | ||||
| @ -5,25 +5,26 @@ | ||||
|  | ||||
| namespace { | ||||
| template <typename scalar_t> | ||||
| void copy_blocks_cpu_impl( | ||||
|     std::vector<torch::Tensor> &key_caches, | ||||
|     std::vector<torch::Tensor> &value_caches, | ||||
|     const std::vector<std::pair<int64_t, int64_t>> mapping_pairs, | ||||
|     const int element_num_per_block, const int layer_num) { | ||||
|   const size_t pair_num = mapping_pairs.size(); | ||||
| void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches, | ||||
|                           std::vector<torch::Tensor>& value_caches, | ||||
|                           const torch::Tensor& mapping_pairs, | ||||
|                           const int element_num_per_block, | ||||
|                           const int layer_num) { | ||||
|   const size_t pair_num = mapping_pairs.size(0); | ||||
|   const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; | ||||
| #pragma omp parallel for collapse(2) | ||||
|   for (int layer = 0; layer < layer_num; ++layer) { | ||||
|     for (size_t pair = 0; pair < pair_num; ++pair) { | ||||
|       int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; | ||||
|       int64_t source_offset = | ||||
|           element_num_per_block * mapping_pairs[pair][0].item<int64_t>(); | ||||
|       int64_t target_offset = | ||||
|           element_num_per_block * mapping_pairs[pair].second; | ||||
|       scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); | ||||
|       scalar_t *source_ptr = key_cache_ptr + source_offset; | ||||
|       scalar_t *target_ptr = key_cache_ptr + target_offset; | ||||
|           element_num_per_block * mapping_pairs[pair][1].item<int64_t>(); | ||||
|       scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); | ||||
|       scalar_t* source_ptr = key_cache_ptr + source_offset; | ||||
|       scalar_t* target_ptr = key_cache_ptr + target_offset; | ||||
|       std::memcpy(target_ptr, source_ptr, block_bytes); | ||||
|  | ||||
|       scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>(); | ||||
|       scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>(); | ||||
|       source_ptr = value_cache_ptr + source_offset; | ||||
|       target_ptr = value_cache_ptr + target_offset; | ||||
|       std::memcpy(target_ptr, source_ptr, block_bytes); | ||||
| @ -33,9 +34,9 @@ void copy_blocks_cpu_impl( | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void reshape_and_cache_cpu_impl( | ||||
|     const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, | ||||
|     scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, | ||||
|     const int64_t *__restrict__ slot_mapping, const int num_tokens, | ||||
|     const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, | ||||
|     scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, | ||||
|     const int64_t* __restrict__ slot_mapping, const int num_tokens, | ||||
|     const int key_stride, const int value_stride, const int num_heads, | ||||
|     const int head_size, const int block_size, const int x) { | ||||
|   const int block_elem_num = num_heads * head_size * block_size; | ||||
| @ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( | ||||
|         int src_key_head_idx = token_idx * key_stride + head_idx * head_size; | ||||
|         int src_value_head_idx = | ||||
|             token_idx * value_stride + head_idx * head_size; | ||||
|         const scalar_t *src_key_head_ptr = key + src_key_head_idx; | ||||
|         const scalar_t *src_value_head_ptr = value + src_value_head_idx; | ||||
|         const scalar_t* src_key_head_ptr = key + src_key_head_idx; | ||||
|         const scalar_t* src_value_head_ptr = value + src_value_head_idx; | ||||
|         const int64_t block_index = slot_idx / block_size; | ||||
|         const int64_t block_offset = slot_idx % block_size; | ||||
|         scalar_t *target_key_head_ptr = key_cache + | ||||
|         scalar_t* target_key_head_ptr = key_cache + | ||||
|                                         block_elem_num * block_index + | ||||
|                                         head_idx * block_size * head_size; | ||||
|         scalar_t *target_value_head_ptr = value_cache + | ||||
|         scalar_t* target_value_head_ptr = value_cache + | ||||
|                                           block_elem_num * block_index + | ||||
|                                           head_idx * block_size * head_size; | ||||
|  | ||||
| @ -79,39 +80,31 @@ void reshape_and_cache_cpu_impl( | ||||
|     } | ||||
|   } | ||||
| } | ||||
| }; // namespace | ||||
| };  // namespace | ||||
|  | ||||
| void copy_blocks(std::vector<torch::Tensor> &key_caches, | ||||
|                  std::vector<torch::Tensor> &value_caches, | ||||
|                  const std::map<int64_t, std::vector<int64_t>> &block_mapping) { | ||||
|   int num_layers = key_caches.size(); | ||||
| void copy_blocks(std::vector<torch::Tensor>& key_caches, | ||||
|                  std::vector<torch::Tensor>& value_caches, | ||||
|                  const torch::Tensor& block_mapping) { | ||||
|   unsigned num_layers = key_caches.size(); | ||||
|   TORCH_CHECK(num_layers == value_caches.size()); | ||||
|   if (num_layers == 0) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   std::vector<std::pair<int64_t, int64_t>> mapping_pairs; | ||||
|   mapping_pairs.reserve(block_mapping.size()); | ||||
|   for (const auto &pair : block_mapping) { | ||||
|     for (const auto &dst : pair.second) { | ||||
|       mapping_pairs.emplace_back(pair.first, dst); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   const int element_num_per_block = key_caches[0][0].numel(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { | ||||
|         CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) | ||||
|         copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs, | ||||
|         copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping, | ||||
|                                        element_num_per_block, num_layers); | ||||
|         CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, | ||||
|                        torch::Tensor &key_cache, torch::Tensor &value_cache, | ||||
|                        torch::Tensor &slot_mapping, | ||||
|                        const std::string &kv_cache_dtype, float kv_scale) { | ||||
| void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, | ||||
|                        torch::Tensor& key_cache, torch::Tensor& value_cache, | ||||
|                        torch::Tensor& slot_mapping, | ||||
|                        const std::string& kv_cache_dtype, float kv_scale) { | ||||
|   TORCH_CHECK(kv_scale == 1.0f); | ||||
|  | ||||
|   int num_tokens = key.size(0); | ||||
| @ -135,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void swap_blocks(torch::Tensor &src, torch::Tensor &dst, | ||||
|                  const std::map<int64_t, int64_t> &block_mapping) { | ||||
| void swap_blocks(torch::Tensor& src, torch::Tensor& dst, | ||||
|                  const torch::Tensor& block_mapping) { | ||||
|   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") | ||||
| } | ||||
|  | ||||
| @ -2,10 +2,10 @@ | ||||
|  | ||||
| namespace { | ||||
| template <typename scalar_t> | ||||
| void rms_norm_impl(scalar_t *__restrict__ out, | ||||
|                        const scalar_t *__restrict__ input, | ||||
|                        const scalar_t *__restrict__ weight, const float epsilon, | ||||
|                        const int num_tokens, const int hidden_size) { | ||||
| void rms_norm_impl(scalar_t* __restrict__ out, | ||||
|                    const scalar_t* __restrict__ input, | ||||
|                    const scalar_t* __restrict__ weight, const float epsilon, | ||||
|                    const int num_tokens, const int hidden_size) { | ||||
|   using scalar_vec_t = vec_op::vec_t<scalar_t>; | ||||
|   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); | ||||
|   TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); | ||||
| @ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void fused_add_rms_norm_impl(scalar_t *__restrict__ input, | ||||
|                                  scalar_t *__restrict__ residual, | ||||
|                                  const scalar_t *__restrict__ weight, | ||||
|                                  const float epsilon, const int num_tokens, | ||||
|                                  const int hidden_size) { | ||||
| void fused_add_rms_norm_impl(scalar_t* __restrict__ input, | ||||
|                              scalar_t* __restrict__ residual, | ||||
|                              const scalar_t* __restrict__ weight, | ||||
|                              const float epsilon, const int num_tokens, | ||||
|                              const int hidden_size) { | ||||
|   using scalar_vec_t = vec_op::vec_t<scalar_t>; | ||||
|   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); | ||||
|   TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); | ||||
| @ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, | ||||
|     } | ||||
|   } | ||||
| } | ||||
| } // namespace | ||||
| }  // namespace | ||||
|  | ||||
| void rms_norm(torch::Tensor &out, torch::Tensor &input, | ||||
|                   torch::Tensor &weight, float epsilon) { | ||||
| void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, | ||||
|               float epsilon) { | ||||
|   int hidden_size = input.size(-1); | ||||
|   int num_tokens = input.numel() / hidden_size; | ||||
|  | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { | ||||
|     CPU_KERNEL_GUARD_IN(rms_norm_impl) | ||||
|     rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), | ||||
|                       weight.data_ptr<scalar_t>(), epsilon, num_tokens, | ||||
|                       hidden_size); | ||||
|                   weight.data_ptr<scalar_t>(), epsilon, num_tokens, | ||||
|                   hidden_size); | ||||
|     CPU_KERNEL_GUARD_OUT(rms_norm_impl) | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, | ||||
|                             torch::Tensor &weight, float epsilon) { | ||||
| void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, | ||||
|                         torch::Tensor& weight, float epsilon) { | ||||
|   int hidden_size = input.size(-1); | ||||
|   int num_tokens = input.numel() / hidden_size; | ||||
|  | ||||
|  | ||||
| @ -4,107 +4,107 @@ | ||||
| namespace { | ||||
| template <typename scalar_t> | ||||
| void rotary_embedding_impl( | ||||
|     const int64_t | ||||
|         *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] | ||||
|     scalar_t | ||||
|         *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or | ||||
|                              /// [num_tokens, num_heads, head_size] | ||||
|     scalar_t | ||||
|         *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or | ||||
|                            // [num_tokens, num_kv_heads, head_size] | ||||
|     const scalar_t | ||||
|         *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] | ||||
|     const int64_t* __restrict__ positions,  // [batch_size, seq_len] or | ||||
|                                             // [num_tokens] | ||||
|     scalar_t* __restrict__ query,           /// [batch_size, seq_len, num_heads, | ||||
|                                    /// head_size] or [num_tokens, num_heads, | ||||
|                                    /// head_size] | ||||
|     scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads, | ||||
|                                  // head_size] or [num_tokens, num_kv_heads, | ||||
|                                  // head_size] | ||||
|     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // | ||||
|                                                  // 2] | ||||
|     const int rot_dim, const int64_t query_stride, const int64_t key_stride, | ||||
|     const int num_heads, const int num_kv_heads, const int head_size, | ||||
|     const int num_tokens) { | ||||
|   using scalar_vec_t = vec_op::vec_t<scalar_t>; | ||||
|   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); | ||||
|   constexpr int ELEM_SIZE = sizeof(scalar_t); | ||||
|  | ||||
|   const int embed_dim = rot_dim / 2; | ||||
|   TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); | ||||
|   bool flag = (embed_dim % VEC_ELEM_NUM == 0); | ||||
|   const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM; | ||||
|  | ||||
|   auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr, | ||||
|                           scalar_t* qk) { | ||||
|     int j = 0; | ||||
|     for (; j < loop_upper; j += VEC_ELEM_NUM) { | ||||
|       const int rot_offset = j; | ||||
|       const int x_index = rot_offset; | ||||
|       const int y_index = embed_dim + rot_offset; | ||||
|  | ||||
|       const int64_t out_x = token_head + x_index; | ||||
|       const int64_t out_y = token_head + y_index; | ||||
|  | ||||
|       const scalar_vec_t cos(cache_ptr + x_index); | ||||
|       const scalar_vec_t sin(cache_ptr + y_index); | ||||
|  | ||||
|       const scalar_vec_t q_x(qk + out_x); | ||||
|       const scalar_vec_t q_y(qk + out_y); | ||||
|  | ||||
|       vec_op::FP32Vec8 fp32_cos(cos); | ||||
|       vec_op::FP32Vec8 fp32_sin(sin); | ||||
|  | ||||
|       vec_op::FP32Vec8 fp32_q_x(q_x); | ||||
|       vec_op::FP32Vec8 fp32_q_y(q_y); | ||||
|  | ||||
|       auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; | ||||
|       scalar_vec_t(out1).save(qk + out_x); | ||||
|  | ||||
|       auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; | ||||
|       scalar_vec_t(out2).save(qk + out_y); | ||||
|     } | ||||
|     if (!flag) { | ||||
|       for (; j < embed_dim; ++j) { | ||||
|         const int x_index = j; | ||||
|         const int y_index = embed_dim + j; | ||||
|  | ||||
|         const int64_t out_x = token_head + x_index; | ||||
|         const int64_t out_y = token_head + y_index; | ||||
|  | ||||
|         const float fp32_cos = cache_ptr[x_index]; | ||||
|         const float fp32_sin = cache_ptr[y_index]; | ||||
|  | ||||
|         const float fp32_q_x = qk[out_x]; | ||||
|         const float fp32_q_y = qk[out_y]; | ||||
|  | ||||
|         qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; | ||||
|         qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; | ||||
|       } | ||||
|     } | ||||
|   }; | ||||
|  | ||||
| #pragma omp parallel for | ||||
|   for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { | ||||
|     int64_t pos = positions[token_idx]; | ||||
|     const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; | ||||
|     const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; | ||||
|  | ||||
|     for (int i = 0; i < num_heads; ++i) { | ||||
|       const int head_idx = i; | ||||
|       const int64_t token_head = | ||||
|           token_idx * query_stride + head_idx * head_size; | ||||
|       for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { | ||||
|         const int rot_offset = j; | ||||
|         const int x_index = rot_offset; | ||||
|         const int y_index = embed_dim + rot_offset; | ||||
|  | ||||
|         const int64_t out_x = token_head + x_index; | ||||
|         const int64_t out_y = token_head + y_index; | ||||
|  | ||||
|         const scalar_vec_t cos(cache_ptr + x_index); | ||||
|         const scalar_vec_t sin(cache_ptr + y_index); | ||||
|  | ||||
|         const scalar_vec_t q_x(query + out_x); | ||||
|         const scalar_vec_t q_y(query + out_y); | ||||
|  | ||||
|         vec_op::FP32Vec8 fp32_cos(cos); | ||||
|         vec_op::FP32Vec8 fp32_sin(sin); | ||||
|  | ||||
|         vec_op::FP32Vec8 fp32_q_x(q_x); | ||||
|         vec_op::FP32Vec8 fp32_q_y(q_y); | ||||
|  | ||||
|         auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; | ||||
|         scalar_vec_t(out1).save(query + out_x); | ||||
|  | ||||
|         auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; | ||||
|         scalar_vec_t(out2).save(query + out_y); | ||||
|       } | ||||
|       compute_loop(token_head, cache_ptr, query); | ||||
|     } | ||||
|  | ||||
|     for (int i = 0; i < num_kv_heads; ++i) { | ||||
|       const int head_idx = i; | ||||
|       const int64_t token_head = token_idx * key_stride + head_idx * head_size; | ||||
|       for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { | ||||
|         const int rot_offset = j; | ||||
|         const int x_index = rot_offset; | ||||
|         const int y_index = embed_dim + rot_offset; | ||||
|  | ||||
|         const int64_t out_x = token_head + x_index; | ||||
|         const int64_t out_y = token_head + y_index; | ||||
|  | ||||
|         const scalar_vec_t cos(cache_ptr + x_index); | ||||
|         const scalar_vec_t sin(cache_ptr + y_index); | ||||
|  | ||||
|         const scalar_vec_t k_x(key + out_x); | ||||
|         const scalar_vec_t k_y(key + out_y); | ||||
|  | ||||
|         vec_op::FP32Vec8 fp32_cos(cos); | ||||
|         vec_op::FP32Vec8 fp32_sin(sin); | ||||
|  | ||||
|         vec_op::FP32Vec8 fp32_k_x(k_x); | ||||
|         vec_op::FP32Vec8 fp32_k_y(k_y); | ||||
|  | ||||
|         auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; | ||||
|         scalar_vec_t(out1).save(key + out_x); | ||||
|         auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; | ||||
|         scalar_vec_t(out2).save(key + out_y); | ||||
|       } | ||||
|       compute_loop(token_head, cache_ptr, key); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void rotary_embedding_gptj_impl( | ||||
|     const int64_t | ||||
|         *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] | ||||
|     scalar_t | ||||
|         *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or | ||||
|                              /// [num_tokens, num_heads, head_size] | ||||
|     scalar_t | ||||
|         *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or | ||||
|                            // [num_tokens, num_kv_heads, head_size] | ||||
|     const scalar_t | ||||
|         *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] | ||||
|     const int64_t* __restrict__ positions,  // [batch_size, seq_len] or | ||||
|                                             // [num_tokens] | ||||
|     scalar_t* __restrict__ query,           /// [batch_size, seq_len, num_heads, | ||||
|                                    /// head_size] or [num_tokens, num_heads, | ||||
|                                    /// head_size] | ||||
|     scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads, | ||||
|                                  // head_size] or [num_tokens, num_kv_heads, | ||||
|                                  // head_size] | ||||
|     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // | ||||
|                                                  // 2] | ||||
|     const int rot_dim, const int64_t query_stride, const int64_t key_stride, | ||||
|     const int num_heads, const int num_kv_heads, const int head_size, | ||||
|     const int num_tokens) { | ||||
| @ -114,13 +114,13 @@ void rotary_embedding_gptj_impl( | ||||
|   for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { | ||||
|     for (int i = 0; i < num_heads; ++i) { | ||||
|       int64_t pos = positions[token_idx]; | ||||
|       const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; | ||||
|       const scalar_t *cos_cache_ptr = cache_ptr; | ||||
|       const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; | ||||
|       const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; | ||||
|       const scalar_t* cos_cache_ptr = cache_ptr; | ||||
|       const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; | ||||
|       const int head_idx = i; | ||||
|       const int64_t token_head = | ||||
|           token_idx * query_stride + head_idx * head_size; | ||||
|       scalar_t *head_query = token_head + query; | ||||
|       scalar_t* head_query = token_head + query; | ||||
|       for (int j = 0; j < embed_dim; j += 1) { | ||||
|         const int rot_offset = j; | ||||
|         const int x_index = 2 * rot_offset; | ||||
| @ -142,12 +142,12 @@ void rotary_embedding_gptj_impl( | ||||
|   for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { | ||||
|     for (int i = 0; i < num_kv_heads; ++i) { | ||||
|       int64_t pos = positions[token_idx]; | ||||
|       const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; | ||||
|       const scalar_t *cos_cache_ptr = cache_ptr; | ||||
|       const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; | ||||
|       const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; | ||||
|       const scalar_t* cos_cache_ptr = cache_ptr; | ||||
|       const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; | ||||
|       const int head_idx = i; | ||||
|       const int64_t token_head = token_idx * key_stride + head_idx * head_size; | ||||
|       scalar_t *head_key = key + token_head; | ||||
|       scalar_t* head_key = key + token_head; | ||||
|       for (int j = 0; j < embed_dim; j += 1) { | ||||
|         const int rot_offset = j; | ||||
|         const int x_index = 2 * rot_offset; | ||||
| @ -165,11 +165,11 @@ void rotary_embedding_gptj_impl( | ||||
|     } | ||||
|   } | ||||
| } | ||||
| }; // namespace | ||||
| };  // namespace | ||||
|  | ||||
| void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, | ||||
|                           torch::Tensor &key, int head_size, | ||||
|                           torch::Tensor &cos_sin_cache, bool is_neox) { | ||||
| void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, | ||||
|                       torch::Tensor& key, int head_size, | ||||
|                       torch::Tensor& cos_sin_cache, bool is_neox) { | ||||
|   int num_tokens = query.numel() / query.size(-1); | ||||
|   int rot_dim = cos_sin_cache.size(1); | ||||
|   int num_heads = query.size(-1) / head_size; | ||||
|  | ||||
| @ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
|   pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); | ||||
|  | ||||
|   // Attention ops | ||||
|   ops.def( | ||||
|     "paged_attention_v1", | ||||
|     &paged_attention_v1, | ||||
|     "Compute the attention between an input query and the cached keys/values using PagedAttention."); | ||||
|   ops.def( | ||||
|     "paged_attention_v2", | ||||
|     &paged_attention_v2, | ||||
|     "PagedAttention V2."); | ||||
|   ops.def("paged_attention_v1", &paged_attention_v1, | ||||
|           "Compute the attention between an input query and the cached " | ||||
|           "keys/values using PagedAttention."); | ||||
|   ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); | ||||
|  | ||||
|   // Activation ops | ||||
|   ops.def( | ||||
|     "silu_and_mul", | ||||
|     &silu_and_mul, | ||||
|     "Activation function used in SwiGLU."); | ||||
|   ops.def( | ||||
|     "gelu_and_mul", | ||||
|     &gelu_and_mul, | ||||
|     "Activation function used in GeGLU with `none` approximation."); | ||||
|   ops.def( | ||||
|     "gelu_tanh_and_mul", | ||||
|     &gelu_tanh_and_mul, | ||||
|     "Activation function used in GeGLU with `tanh` approximation."); | ||||
|   ops.def( | ||||
|     "gelu_new", | ||||
|     &gelu_new, | ||||
|     "GELU implementation used in GPT-2."); | ||||
|   ops.def( | ||||
|     "gelu_fast", | ||||
|     &gelu_fast, | ||||
|     "Approximate GELU implementation."); | ||||
|   ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); | ||||
|   ops.def("gelu_and_mul", &gelu_and_mul, | ||||
|           "Activation function used in GeGLU with `none` approximation."); | ||||
|   ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, | ||||
|           "Activation function used in GeGLU with `tanh` approximation."); | ||||
|   ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); | ||||
|   ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); | ||||
|  | ||||
|   // Layernorm | ||||
|   ops.def( | ||||
|     "rms_norm", | ||||
|     &rms_norm, | ||||
|     "Apply Root Mean Square (RMS) Normalization to the input tensor."); | ||||
|   ops.def("rms_norm", &rms_norm, | ||||
|           "Apply Root Mean Square (RMS) Normalization to the input tensor."); | ||||
|  | ||||
|   ops.def( | ||||
|     "fused_add_rms_norm", | ||||
|     &fused_add_rms_norm, | ||||
|     "In-place fused Add and RMS Normalization"); | ||||
|   ops.def("fused_add_rms_norm", &fused_add_rms_norm, | ||||
|           "In-place fused Add and RMS Normalization"); | ||||
|  | ||||
|   // Rotary embedding | ||||
|   ops.def( | ||||
|     "rotary_embedding", | ||||
|     &rotary_embedding, | ||||
|     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); | ||||
|   ops.def("rotary_embedding", &rotary_embedding, | ||||
|           "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); | ||||
|  | ||||
|   // Cache ops | ||||
|   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); | ||||
|   cache_ops.def( | ||||
|     "swap_blocks", | ||||
|     &swap_blocks, | ||||
|     "Swap in (out) the cache blocks from src to dst"); | ||||
|   cache_ops.def( | ||||
|     "copy_blocks", | ||||
|     ©_blocks, | ||||
|     "Copy the cache blocks from src to dst"); | ||||
|   cache_ops.def( | ||||
|     "reshape_and_cache", | ||||
|     &reshape_and_cache, | ||||
|     "Reshape the key and value tensors and cache them"); | ||||
|   cache_ops.def("swap_blocks", &swap_blocks, | ||||
|                 "Swap in (out) the cache blocks from src to dst"); | ||||
|   cache_ops.def("copy_blocks", ©_blocks, | ||||
|                 "Copy the cache blocks from src to dst"); | ||||
|   cache_ops.def("reshape_and_cache", &reshape_and_cache, | ||||
|                 "Reshape the key and value tensors and cache them"); | ||||
| } | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| #pragma once | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| #include <hip/hip_runtime.h> | ||||
|   #include <hip/hip_runtime.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
| @ -17,9 +17,14 @@ | ||||
| #endif | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) | ||||
|   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ | ||||
|     __shfl_xor_sync(uint32_t(-1), var, lane_mask) | ||||
|   #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ | ||||
|     __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) | ||||
| #else | ||||
|   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) | ||||
|   #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ | ||||
|     __shfl_xor(var, lane_mask, width) | ||||
| #endif | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
| @ -28,6 +33,13 @@ | ||||
|   #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) | ||||
| #endif | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ | ||||
|     __shfl_down_sync(uint32_t(-1), var, lane_delta) | ||||
| #else | ||||
|   #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) | ||||
| #endif | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ | ||||
|     cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) | ||||
| @ -35,4 +47,3 @@ | ||||
|   #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ | ||||
|     hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) | ||||
| #endif | ||||
|  | ||||
|  | ||||
| @ -2,9 +2,6 @@ | ||||
|  | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| int get_device_attribute( | ||||
|     int attribute, | ||||
|     int device_id); | ||||
| int get_device_attribute(int attribute, int device_id); | ||||
|  | ||||
| int get_max_shared_memory_per_block_device_attribute( | ||||
|     int device_id); | ||||
| int get_max_shared_memory_per_block_device_attribute(int device_id); | ||||
|  | ||||
| @ -2,34 +2,28 @@ | ||||
|   #include <hip/hip_runtime.h> | ||||
|   #include <hip/hip_runtime_api.h> | ||||
| #endif | ||||
| int get_device_attribute( | ||||
|     int attribute, | ||||
|     int device_id) | ||||
| { | ||||
|     int device, value; | ||||
|     if (device_id < 0) { | ||||
|         cudaGetDevice(&device); | ||||
|     } | ||||
|     else { | ||||
|         device = device_id; | ||||
|     } | ||||
|     cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device); | ||||
|     return value; | ||||
| int get_device_attribute(int attribute, int device_id) { | ||||
|   int device, value; | ||||
|   if (device_id < 0) { | ||||
|     cudaGetDevice(&device); | ||||
|   } else { | ||||
|     device = device_id; | ||||
|   } | ||||
|   cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), | ||||
|                          device); | ||||
|   return value; | ||||
| } | ||||
|  | ||||
|  | ||||
| int get_max_shared_memory_per_block_device_attribute( | ||||
|     int device_id) | ||||
| { | ||||
| int attribute;     | ||||
| // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html | ||||
| // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 | ||||
| int get_max_shared_memory_per_block_device_attribute(int device_id) { | ||||
|   int attribute; | ||||
|   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html | ||||
|   // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|     attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; | ||||
|   attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; | ||||
| #else | ||||
|     attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; | ||||
|   attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; | ||||
| #endif | ||||
|  | ||||
|     return get_device_attribute(attribute, device_id); | ||||
|   return get_device_attribute(attribute, device_id); | ||||
| } | ||||
|  | ||||
| @ -7,11 +7,11 @@ | ||||
|  | ||||
| // fake pointer type | ||||
| using fptr_t = uint64_t; | ||||
| static_assert(sizeof(void *) == sizeof(fptr_t)); | ||||
| static_assert(sizeof(void*) == sizeof(fptr_t)); | ||||
|  | ||||
| fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, | ||||
|                       const std::vector<std::string> &handles, | ||||
|                       const std::vector<int64_t> &offsets, int rank, | ||||
| fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, | ||||
|                       const std::vector<std::string>& handles, | ||||
|                       const std::vector<int64_t>& offsets, int rank, | ||||
|                       bool full_nvlink) { | ||||
|   int world_size = offsets.size(); | ||||
|   if (world_size > 8) | ||||
| @ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, | ||||
|     std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); | ||||
|   } | ||||
|   return (fptr_t) new vllm::CustomAllreduce( | ||||
|       reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(), | ||||
|       reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(), | ||||
|       rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); | ||||
| } | ||||
|  | ||||
| @ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, | ||||
|  * 5. A[None].expand(2, -1, -1, -1): Not OK | ||||
|  * 6. A[:, 1:, 1:]: Not OK | ||||
|  */ | ||||
| bool _is_weak_contiguous(torch::Tensor &t) { | ||||
| bool _is_weak_contiguous(torch::Tensor& t) { | ||||
|   return t.is_contiguous() || | ||||
|          (t.storage().nbytes() - t.storage_offset() * t.element_size() == | ||||
|           t.numel() * t.element_size()); | ||||
| } | ||||
|  | ||||
| bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, | ||||
| bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, | ||||
|                       bool full_nvlink) { | ||||
|   auto inp_size = inp.numel() * inp.element_size(); | ||||
|   // custom allreduce requires input byte size to be multiples of 16 | ||||
| @ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, | ||||
| void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, | ||||
|                  cudaStream_t stream) { | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); | ||||
|   TORCH_CHECK(_is_weak_contiguous(out)); | ||||
|   switch (out.scalar_type()) { | ||||
|     case at::ScalarType::Float: { | ||||
|       fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()), | ||||
|                            reinterpret_cast<float *>(out.data_ptr()), | ||||
|       fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()), | ||||
|                            reinterpret_cast<float*>(out.data_ptr()), | ||||
|                            out.numel()); | ||||
|       break; | ||||
|     } | ||||
|     case at::ScalarType::Half: { | ||||
|       fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()), | ||||
|                           reinterpret_cast<half *>(out.data_ptr()), | ||||
|                           out.numel()); | ||||
|       fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()), | ||||
|                           reinterpret_cast<half*>(out.data_ptr()), out.numel()); | ||||
|       break; | ||||
|     } | ||||
| #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) | ||||
|     case at::ScalarType::BFloat16: { | ||||
|       fa->allreduce<nv_bfloat16>( | ||||
|           stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()), | ||||
|           reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel()); | ||||
|           stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()), | ||||
|           reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel()); | ||||
|       break; | ||||
|     } | ||||
| #endif | ||||
| @ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, | ||||
|   } | ||||
| } | ||||
|  | ||||
| void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { | ||||
| void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); | ||||
|   auto stream = c10::cuda::getCurrentCUDAStream().stream(); | ||||
|   TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); | ||||
| @ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { | ||||
|   _all_reduce(_fa, inp, out, stream); | ||||
| } | ||||
|  | ||||
| void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, | ||||
|                       torch::Tensor &out) { | ||||
| void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, | ||||
|                       torch::Tensor& out) { | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); | ||||
|   auto stream = c10::cuda::getCurrentCUDAStream().stream(); | ||||
|  | ||||
| @ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, | ||||
| } | ||||
|  | ||||
| void dispose(fptr_t _fa) { | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); | ||||
|   delete fa; | ||||
| } | ||||
|  | ||||
| int meta_size() { return sizeof(vllm::Signal); } | ||||
|  | ||||
| void register_buffer(fptr_t _fa, torch::Tensor &t, | ||||
|                      const std::vector<std::string> &handles, | ||||
|                      const std::vector<int64_t> &offsets) { | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||
| void register_buffer(fptr_t _fa, torch::Tensor& t, | ||||
|                      const std::vector<std::string>& handles, | ||||
|                      const std::vector<int64_t>& offsets) { | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); | ||||
|   fa->register_buffer(handles, offsets, t.data_ptr()); | ||||
| } | ||||
|  | ||||
| std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( | ||||
|     fptr_t _fa) { | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); | ||||
|   return fa->get_graph_buffer_ipc_meta(); | ||||
| } | ||||
|  | ||||
| void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, | ||||
|                             const std::vector<std::vector<int64_t>> &offsets) { | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||
| void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles, | ||||
|                             const std::vector<std::vector<int64_t>>& offsets) { | ||||
|   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); | ||||
|   fa->register_graph_buffers(handles, offsets); | ||||
| } | ||||
|  | ||||
| @ -31,9 +31,9 @@ struct Signal { | ||||
|   alignas(128) uint32_t end[kMaxBlocks][8]; | ||||
| }; | ||||
|  | ||||
| struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; | ||||
| struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; | ||||
|  | ||||
| struct __align__(16) RankSignals { volatile Signal *signals[8]; }; | ||||
| struct __align__(16) RankSignals { volatile Signal* signals[8]; }; | ||||
|  | ||||
| // like std::array, but aligned | ||||
| template <typename T, int sz> | ||||
| @ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { | ||||
| // scalar add functions | ||||
| // for some reason when compiling with Pytorch, the + operator for half and | ||||
| // bfloat is disabled so we call the intrinsics directly | ||||
| DINLINE half &assign_add(half &a, half b) { | ||||
| DINLINE half& assign_add(half& a, half b) { | ||||
|   a = __hadd(a, b); | ||||
|   return a; | ||||
| } | ||||
| DINLINE float &assign_add(float &a, float b) { return a += b; } | ||||
| DINLINE float& assign_add(float& a, float b) { return a += b; } | ||||
|  | ||||
| #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) | ||||
| DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } | ||||
| @ -80,14 +80,14 @@ template <> | ||||
| DINLINE nv_bfloat16 downcast_s(float val) { | ||||
|   return __float2bfloat16(val); | ||||
| } | ||||
| DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { | ||||
| DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { | ||||
|   a = __hadd(a, b); | ||||
|   return a; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| template <typename T, int N> | ||||
| DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { | ||||
| DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) { | ||||
| #pragma unroll | ||||
|   for (int i = 0; i < N; i++) { | ||||
|     assign_add(a.data[i], b.data[i]); | ||||
| @ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) { | ||||
| // prior memory accesses. Note: volatile writes will not be reordered against | ||||
| // other volatile writes. | ||||
| template <int ngpus> | ||||
| DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, | ||||
| DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, | ||||
|                         int rank) { | ||||
|   if (threadIdx.x < ngpus) { | ||||
|     // reset flag for next time | ||||
| @ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, | ||||
|     // Latency = 1 p2p write | ||||
|     sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; | ||||
|     // wait until we got true from all ranks | ||||
|     while (!self_sg->start[blockIdx.x][threadIdx.x]) | ||||
|       ; | ||||
|     while (!self_sg->start[blockIdx.x][threadIdx.x]); | ||||
|   } | ||||
|   __syncthreads(); | ||||
| } | ||||
| @ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, | ||||
| // barrier in the all reduce kernel. If it's the final synchronization barrier, | ||||
| // we don't need to make any visibility guarantees for prior memory accesses. | ||||
| template <int ngpus, bool final_sync = false> | ||||
| DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, | ||||
| DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, | ||||
|                       int rank) { | ||||
|   __syncthreads(); | ||||
|   // eliminate the case that prior writes are not visible after signals become | ||||
|   // visible. Note that I did not managed to make this happen through a lot of | ||||
|   // testing. Might be the case that hardware provides stronger guarantee than | ||||
|   // the memory model.  | ||||
|   // the memory model. | ||||
|   if constexpr (!final_sync) __threadfence_system(); | ||||
|   if (threadIdx.x < ngpus) { | ||||
|     // reset flag for next time | ||||
| @ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, | ||||
|     // Latency = 1 p2p write | ||||
|     sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; | ||||
|     // wait until we got true from all ranks | ||||
|     while (!self_sg->end[blockIdx.x][threadIdx.x]) | ||||
|       ; | ||||
|     while (!self_sg->end[blockIdx.x][threadIdx.x]); | ||||
|   } | ||||
|   if constexpr (!final_sync) __syncthreads(); | ||||
| } | ||||
|  | ||||
| template <typename P, int ngpus, typename A> | ||||
| DINLINE P packed_reduce(const P *ptrs[], int idx) { | ||||
| DINLINE P packed_reduce(const P* ptrs[], int idx) { | ||||
|   A tmp = upcast(ptrs[0][idx]); | ||||
| #pragma unroll | ||||
|   for (int i = 1; i < ngpus; i++) { | ||||
| @ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { | ||||
|  | ||||
| template <typename T, int ngpus> | ||||
| __global__ void __launch_bounds__(512, 1) | ||||
|     cross_device_reduce_1stage(RankData *_dp, RankSignals sg, | ||||
|                                volatile Signal *self_sg, T *__restrict__ result, | ||||
|     cross_device_reduce_1stage(RankData* _dp, RankSignals sg, | ||||
|                                volatile Signal* self_sg, T* __restrict__ result, | ||||
|                                int rank, int size) { | ||||
|   using P = typename packed_t<T>::P; | ||||
|   using A = typename packed_t<T>::A; | ||||
| @ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) | ||||
|   // do the actual reduction | ||||
|   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||
|        idx += gridDim.x * blockDim.x) { | ||||
|     ((P *)result)[idx] = | ||||
|         packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx); | ||||
|     ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx); | ||||
|   } | ||||
|   end_sync<ngpus, true>(sg, self_sg, rank); | ||||
| } | ||||
|  | ||||
| template <typename P> | ||||
| DINLINE P *get_tmp_buf(volatile Signal *sg) { | ||||
|   return (P *)(((Signal *)sg) + 1); | ||||
| DINLINE P* get_tmp_buf(volatile Signal* sg) { | ||||
|   return (P*)(((Signal*)sg) + 1); | ||||
| } | ||||
|  | ||||
| template <typename T, int ngpus> | ||||
| __global__ void __launch_bounds__(512, 1) | ||||
|     cross_device_reduce_2stage(RankData *_dp, RankSignals sg, | ||||
|                                volatile Signal *self_sg, T *__restrict__ result, | ||||
|     cross_device_reduce_2stage(RankData* _dp, RankSignals sg, | ||||
|                                volatile Signal* self_sg, T* __restrict__ result, | ||||
|                                int rank, int size) { | ||||
|   int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|   int stride = gridDim.x * blockDim.x; | ||||
| @ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) | ||||
|   int start = rank * part; | ||||
|   int end = rank == ngpus - 1 ? size : start + part; | ||||
|   int largest_part = part + size % ngpus; | ||||
|   const P *ptrs[ngpus]; | ||||
|   P *tmps[ngpus]; | ||||
|   const P* ptrs[ngpus]; | ||||
|   P* tmps[ngpus]; | ||||
| #pragma unroll | ||||
|   for (int i = 0; i < ngpus; i++) { | ||||
|     int target = (rank + i) % ngpus; | ||||
|     ptrs[i] = (const P *)_dp->ptrs[target]; | ||||
|     ptrs[i] = (const P*)_dp->ptrs[target]; | ||||
|     tmps[i] = get_tmp_buf<P>(sg.signals[target]); | ||||
|   } | ||||
|   auto tmp_out = tmps[0]; | ||||
| @ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) | ||||
|       int gather_from_rank = ((rank + i) % ngpus); | ||||
|       if (gather_from_rank == ngpus - 1 || idx < part) { | ||||
|         int dst_idx = gather_from_rank * part + idx; | ||||
|         ((P *)result)[dst_idx] = tmps[i][idx]; | ||||
|         ((P*)result)[dst_idx] = tmps[i][idx]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| @ -261,14 +258,14 @@ class CustomAllreduce { | ||||
|  | ||||
|   // below are device pointers | ||||
|   RankSignals sg_; | ||||
|   std::unordered_map<void *, RankData *> buffers_; | ||||
|   Signal *self_sg_; | ||||
|   std::unordered_map<void*, RankData*> buffers_; | ||||
|   Signal* self_sg_; | ||||
|  | ||||
|   // stores the registered device pointers from all ranks | ||||
|   RankData *d_rank_data_base_, *d_rank_data_end_; | ||||
|   std::vector<void *> graph_unreg_buffers_; | ||||
|   std::vector<void*> graph_unreg_buffers_; | ||||
|   // a map from IPC handles to opened IPC pointers | ||||
|   std::map<IPC_KEY, char *> ipc_handles_; | ||||
|   std::map<IPC_KEY, char*> ipc_handles_; | ||||
|  | ||||
|   /** | ||||
|    * meta is a pointer to device metadata and temporary buffer for allreduce. | ||||
| @ -279,22 +276,22 @@ class CustomAllreduce { | ||||
|    * note: this class does not own any device memory. Any required buffers | ||||
|    * are passed in from the constructor | ||||
|    */ | ||||
|   CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, | ||||
|                   const cudaIpcMemHandle_t *handles, | ||||
|                   const std::vector<int64_t> &offsets, int rank, | ||||
|   CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, | ||||
|                   const cudaIpcMemHandle_t* handles, | ||||
|                   const std::vector<int64_t>& offsets, int rank, | ||||
|                   bool full_nvlink = true) | ||||
|       : rank_(rank), | ||||
|         world_size_(offsets.size()), | ||||
|         full_nvlink_(full_nvlink), | ||||
|         self_sg_(meta), | ||||
|         d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)), | ||||
|         d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)), | ||||
|         d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { | ||||
|     for (int i = 0; i < world_size_; i++) { | ||||
|       Signal *rank_sg; | ||||
|       Signal* rank_sg; | ||||
|       if (i != rank_) { | ||||
|         char *handle = open_ipc_handle(&handles[i]); | ||||
|         char* handle = open_ipc_handle(&handles[i]); | ||||
|         handle += offsets[i]; | ||||
|         rank_sg = (Signal *)handle; | ||||
|         rank_sg = (Signal*)handle; | ||||
|       } else { | ||||
|         rank_sg = self_sg_; | ||||
|       } | ||||
| @ -302,13 +299,13 @@ class CustomAllreduce { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   char *open_ipc_handle(const void *ipc_handle) { | ||||
|   char* open_ipc_handle(const void* ipc_handle) { | ||||
|     auto [it, new_handle] = | ||||
|         ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); | ||||
|         ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); | ||||
|     if (new_handle) { | ||||
|       char *ipc_ptr; | ||||
|       CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, | ||||
|                                      *((const cudaIpcMemHandle_t *)ipc_handle), | ||||
|       char* ipc_ptr; | ||||
|       CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, | ||||
|                                      *((const cudaIpcMemHandle_t*)ipc_handle), | ||||
|                                      cudaIpcMemLazyEnablePeerAccess)); | ||||
|       it->second = ipc_ptr; | ||||
|     } | ||||
| @ -323,7 +320,7 @@ class CustomAllreduce { | ||||
|     std::vector<int64_t> offsets(num_buffers); | ||||
|     for (int i = 0; i < num_buffers; i++) { | ||||
|       auto ptr = graph_unreg_buffers_[i]; | ||||
|       void *base_ptr; | ||||
|       void* base_ptr; | ||||
|       // note: must share the base address of each allocation, or we get wrong | ||||
|       // address | ||||
|       if (cuPointerGetAttribute(&base_ptr, | ||||
| @ -331,8 +328,8 @@ class CustomAllreduce { | ||||
|                                 (CUdeviceptr)ptr) != CUDA_SUCCESS) | ||||
|         throw std::runtime_error("failed to get pointer attr"); | ||||
|       CUDACHECK(cudaIpcGetMemHandle( | ||||
|           (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); | ||||
|       offsets[i] = ((char *)ptr) - ((char *)base_ptr); | ||||
|           (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); | ||||
|       offsets[i] = ((char*)ptr) - ((char*)base_ptr); | ||||
|     } | ||||
|     return std::make_pair(handles, offsets); | ||||
|   } | ||||
| @ -344,13 +341,13 @@ class CustomAllreduce { | ||||
|           std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); | ||||
|   } | ||||
|  | ||||
|   void register_buffer(const std::vector<std::string> &handles, | ||||
|                        const std::vector<int64_t> &offsets, void *self) { | ||||
|   void register_buffer(const std::vector<std::string>& handles, | ||||
|                        const std::vector<int64_t>& offsets, void* self) { | ||||
|     check_rank_data_capacity(); | ||||
|     RankData data; | ||||
|     for (int i = 0; i < world_size_; i++) { | ||||
|       if (i != rank_) { | ||||
|         char *handle = open_ipc_handle(handles[i].data()); | ||||
|         char* handle = open_ipc_handle(handles[i].data()); | ||||
|         handle += offsets[i]; | ||||
|         data.ptrs[i] = handle; | ||||
|       } else { | ||||
| @ -371,17 +368,17 @@ class CustomAllreduce { | ||||
|   // got a different address. IPC handles have internal reference counting | ||||
|   // mechanism so overhead should be small. | ||||
|   void register_graph_buffers( | ||||
|       const std::vector<std::string> &handles, | ||||
|       const std::vector<std::vector<int64_t>> &offsets) { | ||||
|       const std::vector<std::string>& handles, | ||||
|       const std::vector<std::vector<int64_t>>& offsets) { | ||||
|     auto num_buffers = graph_unreg_buffers_.size(); | ||||
|     check_rank_data_capacity(num_buffers); | ||||
|     std::vector<RankData> rank_data(num_buffers); | ||||
|     for (int i = 0; i < num_buffers; i++) { | ||||
|       auto self_ptr = graph_unreg_buffers_[i]; | ||||
|       auto &rd = rank_data[i]; | ||||
|       auto& rd = rank_data[i]; | ||||
|       for (int j = 0; j < world_size_; j++) { | ||||
|         if (j != rank_) { | ||||
|           char *handle = | ||||
|           char* handle = | ||||
|               open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); | ||||
|           handle += offsets[j][i]; | ||||
|           rd.ptrs[j] = handle; | ||||
| @ -405,7 +402,7 @@ class CustomAllreduce { | ||||
|    * will cause contention on NVLink bus. | ||||
|    */ | ||||
|   template <typename T> | ||||
|   void allreduce(cudaStream_t stream, T *input, T *output, int size, | ||||
|   void allreduce(cudaStream_t stream, T* input, T* output, int size, | ||||
|                  int threads = 512, int block_limit = 36) { | ||||
|     auto d = packed_t<T>::P::size; | ||||
|     if (size % d != 0) | ||||
| @ -418,7 +415,7 @@ class CustomAllreduce { | ||||
|                                std::to_string(kMaxBlocks) + ". Got " + | ||||
|                                std::to_string(block_limit)); | ||||
|  | ||||
|     RankData *ptrs; | ||||
|     RankData* ptrs; | ||||
|     cudaStreamCaptureStatus status; | ||||
|     CUDACHECK(cudaStreamIsCapturing(stream, &status)); | ||||
|     if (status == cudaStreamCaptureStatusActive) { | ||||
|  | ||||
| @ -48,7 +48,7 @@ __global__ void dummy_kernel() { | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void set_data(T *data, int size, int myRank) { | ||||
| __global__ void set_data(T* data, int size, int myRank) { | ||||
|   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||
|        idx += gridDim.x * blockDim.x) { | ||||
|     data[idx] = myRank * 0.11f; | ||||
| @ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void convert_data(const T *data1, const T *data2, double *fdata1, | ||||
|                              double *fdata2, int size) { | ||||
| __global__ void convert_data(const T* data1, const T* data2, double* fdata1, | ||||
|                              double* fdata2, int size) { | ||||
|   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||
|        idx += gridDim.x * blockDim.x) { | ||||
|     fdata1[idx] = data1[idx]; | ||||
| @ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, | ||||
|   } | ||||
| } | ||||
|  | ||||
| __global__ void init_rand(curandState_t *state, int size, int nRanks) { | ||||
| __global__ void init_rand(curandState_t* state, int size, int nRanks) { | ||||
|   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||
|        idx += gridDim.x * blockDim.x) { | ||||
|     for (int i = 0; i < nRanks; i++) { | ||||
| @ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, | ||||
| __global__ void gen_data(curandState_t* state, T* data, double* ground_truth, | ||||
|                          int myRank, int nRanks, int size) { | ||||
|   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||
|        idx += gridDim.x * blockDim.x) { | ||||
| @ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, | ||||
| void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, | ||||
|          int data_size, bool performance_test) { | ||||
|   T *result; | ||||
|   T* result; | ||||
|   cudaStream_t stream; | ||||
|   CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); | ||||
|   CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); | ||||
| @ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, | ||||
|  | ||||
|   cudaIpcMemHandle_t self_data_handle; | ||||
|   cudaIpcMemHandle_t data_handles[8]; | ||||
|   vllm::Signal *buffer; | ||||
|   T *self_data_copy; | ||||
|   vllm::Signal* buffer; | ||||
|   T* self_data_copy; | ||||
|   /** | ||||
|    * Allocate IPC buffer | ||||
|    * | ||||
| @ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, | ||||
|                          MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), | ||||
|                          MPI_BYTE, MPI_COMM_WORLD)); | ||||
|  | ||||
|   void *rank_data; | ||||
|   void* rank_data; | ||||
|   size_t rank_data_sz = 16 * 1024 * 1024; | ||||
|   CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); | ||||
|   std::vector<int64_t> offsets(nRanks, 0); | ||||
|   vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, | ||||
|                            offsets, myRank); | ||||
|   auto *self_data = | ||||
|       reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) + | ||||
|                             sizeof(vllm::Signal) + data_size * sizeof(T)); | ||||
|   auto* self_data = | ||||
|       reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) + | ||||
|                            sizeof(vllm::Signal) + data_size * sizeof(T)); | ||||
|   // hack buffer registration | ||||
|   { | ||||
|     std::vector<std::string> handles; | ||||
|     handles.reserve(nRanks); | ||||
|     for (int i = 0; i < nRanks; i++) { | ||||
|       char *begin = (char *)&data_handles[i]; | ||||
|       char *end = (char *)&data_handles[i + 1]; | ||||
|       char* begin = (char*)&data_handles[i]; | ||||
|       char* end = (char*)&data_handles[i + 1]; | ||||
|       handles.emplace_back(begin, end); | ||||
|     } | ||||
|     std::vector<int64_t> offsets(nRanks, | ||||
| @ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, | ||||
|     fa.register_buffer(handles, offsets, self_data); | ||||
|   } | ||||
|  | ||||
|   double *ground_truth; | ||||
|   double* ground_truth; | ||||
|   CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); | ||||
|   curandState_t *states; | ||||
|   curandState_t* states; | ||||
|   CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); | ||||
|   init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); | ||||
|   gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, | ||||
| @ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, | ||||
|   CUDACHECK(cudaStreamDestroy(stream)); | ||||
| } | ||||
|  | ||||
| int main(int argc, char **argv) { | ||||
| int main(int argc, char** argv) { | ||||
|   int nRanks, myRank; | ||||
|   MPICHECK(MPI_Init(&argc, &argv)); | ||||
|   MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); | ||||
| @ -296,7 +296,7 @@ int main(int argc, char **argv) { | ||||
|   ncclUniqueId id; | ||||
|   ncclComm_t comm; | ||||
|   if (myRank == 0) ncclGetUniqueId(&id); | ||||
|   MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0, | ||||
|   MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0, | ||||
|                      MPI_COMM_WORLD)); | ||||
|   NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); | ||||
|  | ||||
|  | ||||
| @ -6,32 +6,30 @@ | ||||
|  | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \ | ||||
| #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) | ||||
|  | ||||
| #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \ | ||||
|   AT_DISPATCH_SWITCH(                                             \ | ||||
|     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) | ||||
| #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ | ||||
|   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) | ||||
|  | ||||
| #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)     \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \ | ||||
| #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) | ||||
|  | ||||
| #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)           \ | ||||
|   AT_DISPATCH_SWITCH(                                                    \ | ||||
|     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) | ||||
|      | ||||
| #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)             \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)      \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)      \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)     \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)       \ | ||||
| #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ | ||||
|   AT_DISPATCH_SWITCH(TYPE, NAME,                               \ | ||||
|                      VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) | ||||
|  | ||||
| #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \ | ||||
|   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) | ||||
|  | ||||
| #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)             \ | ||||
|   AT_DISPATCH_SWITCH(                                             \ | ||||
|     TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) | ||||
| #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ | ||||
|   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) | ||||
|  | ||||
| @ -11,26 +11,24 @@ | ||||
|   #include <hip/hip_bf16.h> | ||||
|   #include <hip/hip_fp16.h> | ||||
|  | ||||
|   using __nv_bfloat16 = __hip_bfloat16; | ||||
|   using __nv_bfloat162 = __hip_bfloat162; | ||||
| using __nv_bfloat16 = __hip_bfloat16; | ||||
| using __nv_bfloat162 = __hip_bfloat162; | ||||
| #endif | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| // TODO(woosuk): Further optimize this kernel. | ||||
| template<typename scalar_t> | ||||
| template <typename scalar_t> | ||||
| __global__ void rms_norm_kernel( | ||||
|   scalar_t* __restrict__ out,             // [..., hidden_size] | ||||
|   const scalar_t* __restrict__ input,     // [..., hidden_size] | ||||
|   const scalar_t* __restrict__ weight,    // [hidden_size] | ||||
|   const float epsilon, | ||||
|   const int num_tokens, | ||||
|   const int hidden_size) { | ||||
|     scalar_t* __restrict__ out,           // [..., hidden_size] | ||||
|     const scalar_t* __restrict__ input,   // [..., hidden_size] | ||||
|     const scalar_t* __restrict__ weight,  // [hidden_size] | ||||
|     const float epsilon, const int num_tokens, const int hidden_size) { | ||||
|   __shared__ float s_variance; | ||||
|   float variance = 0.0f; | ||||
|  | ||||
|   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { | ||||
|     const float x = (float) input[blockIdx.x * hidden_size + idx]; | ||||
|     const float x = (float)input[blockIdx.x * hidden_size + idx]; | ||||
|     variance += x * x; | ||||
|   } | ||||
|   variance = blockReduceSum<float>(variance); | ||||
| @ -40,12 +38,12 @@ __global__ void rms_norm_kernel( | ||||
|   __syncthreads(); | ||||
|  | ||||
|   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { | ||||
|     float x = (float) input[blockIdx.x * hidden_size + idx]; | ||||
|     out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; | ||||
|     float x = (float)input[blockIdx.x * hidden_size + idx]; | ||||
|     out[blockIdx.x * hidden_size + idx] = | ||||
|         ((scalar_t)(x * s_variance)) * weight[idx]; | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| /* Converter structs for the conversion from torch types to HIP/CUDA types, | ||||
|    and the associated type conversions within HIP/CUDA. These helpers need | ||||
|    to be implemented for now because the relevant type conversion | ||||
| @ -54,51 +52,68 @@ __global__ void rms_norm_kernel( | ||||
|  | ||||
|    Each struct should have the member static constexpr bool `exists`: | ||||
|    If false, the optimized kernel is not used for the corresponding torch type. | ||||
|    If true, the struct should be fully defined as shown in the examples below.  | ||||
|    If true, the struct should be fully defined as shown in the examples below. | ||||
|  */ | ||||
| template<typename torch_type> | ||||
| struct _typeConvert { static constexpr bool exists = false; }; | ||||
| template <typename torch_type> | ||||
| struct _typeConvert { | ||||
|   static constexpr bool exists = false; | ||||
| }; | ||||
|  | ||||
| #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) | ||||
| // CUDA < 12.0 runs into issues with packed type conversion | ||||
| template<> | ||||
| template <> | ||||
| struct _typeConvert<c10::Half> { | ||||
|   static constexpr bool exists = true; | ||||
|   using hip_type = __half; | ||||
|   using packed_hip_type = __half2; | ||||
|  | ||||
|   __device__ static inline float convert(hip_type x) { return __half2float(x); } | ||||
|   __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } | ||||
|   __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } | ||||
|   __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } | ||||
|   __device__ static inline float2 convert(packed_hip_type x) { | ||||
|     return __half22float2(x); | ||||
|   } | ||||
|   __device__ static inline hip_type convert(float x) { | ||||
|     return __float2half_rn(x); | ||||
|   } | ||||
|   __device__ static inline packed_hip_type convert(float2 x) { | ||||
|     return __float22half2_rn(x); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 | ||||
|   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 | ||||
| // CUDA_ARCH < 800 does not have BF16 support | ||||
| // TODO: Add in ROCm support once public headers handle bf16 maturely | ||||
| template<> | ||||
| template <> | ||||
| struct _typeConvert<c10::BFloat16> { | ||||
|   static constexpr bool exists = true; | ||||
|   using hip_type = __nv_bfloat16; | ||||
|   using packed_hip_type = __nv_bfloat162; | ||||
|  | ||||
|   __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } | ||||
|   __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } | ||||
|   __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } | ||||
|   __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } | ||||
|   __device__ static inline float convert(hip_type x) { | ||||
|     return __bfloat162float(x); | ||||
|   } | ||||
|   __device__ static inline float2 convert(packed_hip_type x) { | ||||
|     return __bfloat1622float2(x); | ||||
|   } | ||||
|   __device__ static inline hip_type convert(float x) { | ||||
|     return __float2bfloat16(x); | ||||
|   } | ||||
|   __device__ static inline packed_hip_type convert(float2 x) { | ||||
|     return __float22bfloat162_rn(x); | ||||
|   } | ||||
| }; | ||||
| #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 | ||||
| #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) | ||||
|   #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 | ||||
| #endif    // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= | ||||
|           // 12000)) | ||||
|  | ||||
| /* Vector POD struct to generate vectorized and packed FP16/BF16 ops | ||||
|    for appropriate specializations of fused_add_rms_norm_kernel. | ||||
|    Only functions that are necessary in that kernel are implemented. | ||||
|    Alignment to 16 bytes is required to use 128-bit global memory ops. | ||||
|  */ | ||||
| template<typename scalar_t, int width> | ||||
| template <typename scalar_t, int width> | ||||
| struct alignas(16) _f16Vec { | ||||
|   /* Not theoretically necessary that width is a power of 2 but should  | ||||
|      almost always be the case for optimization purposes */  | ||||
|   /* Not theoretically necessary that width is a power of 2 but should | ||||
|      almost always be the case for optimization purposes */ | ||||
|   static_assert(width > 0 && (width & (width - 1)) == 0, | ||||
|                 "Width is not a positive power of 2!"); | ||||
|   using Converter = _typeConvert<scalar_t>; | ||||
| @ -108,51 +123,49 @@ struct alignas(16) _f16Vec { | ||||
|  | ||||
|   __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) { | ||||
|     if constexpr (width % 2 == 0) { | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; i += 2) { | ||||
|         T2 temp{data[i], data[i+1]}; | ||||
|         temp += T2{other.data[i], other.data[i+1]}; | ||||
|         T2 temp{data[i], data[i + 1]}; | ||||
|         temp += T2{other.data[i], other.data[i + 1]}; | ||||
|         data[i] = temp.x; | ||||
|         data[i+1] = temp.y; | ||||
|         data[i + 1] = temp.y; | ||||
|       } | ||||
|     } else { | ||||
|       #pragma unroll | ||||
|       for (int i = 0; i < width; ++i) | ||||
|         data[i] += other.data[i]; | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; ++i) data[i] += other.data[i]; | ||||
|     } | ||||
|     return *this; | ||||
|   } | ||||
|  | ||||
|   __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) { | ||||
|     if constexpr (width % 2 == 0) { | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; i += 2) { | ||||
|         T2 temp{data[i], data[i+1]}; | ||||
|         temp *= T2{other.data[i], other.data[i+1]}; | ||||
|         T2 temp{data[i], data[i + 1]}; | ||||
|         temp *= T2{other.data[i], other.data[i + 1]}; | ||||
|         data[i] = temp.x; | ||||
|         data[i+1] = temp.y; | ||||
|         data[i + 1] = temp.y; | ||||
|       } | ||||
|     } else { | ||||
|       #pragma unroll | ||||
|       for (int i = 0; i < width; ++i) | ||||
|         data[i] *= other.data[i]; | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; ++i) data[i] *= other.data[i]; | ||||
|     } | ||||
|     return *this; | ||||
|   } | ||||
|  | ||||
|   __device__ _f16Vec& operator*=(const float scale) { | ||||
|     if constexpr (width % 2 == 0) { | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; i += 2) { | ||||
|         float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); | ||||
|         float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); | ||||
|         temp_f.x *= scale; | ||||
|         temp_f.y *= scale; | ||||
|         T2 temp = Converter::convert(temp_f); | ||||
|         data[i] = temp.x; | ||||
|         data[i+1] = temp.y; | ||||
|         data[i + 1] = temp.y; | ||||
|       } | ||||
|     } else { | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; ++i) { | ||||
|         float temp = Converter::convert(data[i]) * scale; | ||||
|         data[i] = Converter::convert(temp); | ||||
| @ -164,13 +177,13 @@ struct alignas(16) _f16Vec { | ||||
|   __device__ float sum_squares() const { | ||||
|     float result = 0.0f; | ||||
|     if constexpr (width % 2 == 0) { | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; i += 2) { | ||||
|         float2 z = Converter::convert(T2{data[i], data[i+1]}); | ||||
|         float2 z = Converter::convert(T2{data[i], data[i + 1]}); | ||||
|         result += z.x * z.x + z.y * z.y; | ||||
|       } | ||||
|     } else { | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < width; ++i) { | ||||
|         float x = Converter::convert(data[i]); | ||||
|         result += x * x; | ||||
| @ -184,15 +197,13 @@ struct alignas(16) _f16Vec { | ||||
|    Additional optimizations we can make in this case are | ||||
|    packed and vectorized operations, which help with the | ||||
|    memory latency bottleneck. */ | ||||
| template<typename scalar_t, int width> | ||||
| __global__ std::enable_if_t< | ||||
|   (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( | ||||
|   scalar_t* __restrict__ input,           // [..., hidden_size] | ||||
|   scalar_t* __restrict__ residual,        // [..., hidden_size] | ||||
|   const scalar_t* __restrict__ weight,    // [hidden_size] | ||||
|   const float epsilon, | ||||
|   const int num_tokens, | ||||
|   const int hidden_size) { | ||||
| template <typename scalar_t, int width> | ||||
| __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists> | ||||
| fused_add_rms_norm_kernel( | ||||
|     scalar_t* __restrict__ input,         // [..., hidden_size] | ||||
|     scalar_t* __restrict__ residual,      // [..., hidden_size] | ||||
|     const scalar_t* __restrict__ weight,  // [hidden_size] | ||||
|     const float epsilon, const int num_tokens, const int hidden_size) { | ||||
|   // Sanity checks on our vector struct and type-punned pointer arithmetic | ||||
|   static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); | ||||
|   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); | ||||
| @ -203,9 +214,12 @@ __global__ std::enable_if_t< | ||||
|   /* These and the argument pointers are all declared `restrict` as they are | ||||
|      not aliased in practice. Argument pointers should not be dereferenced | ||||
|      in this kernel as that would be undefined behavior */ | ||||
|   auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input); | ||||
|   auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); | ||||
|   auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); | ||||
|   auto* __restrict__ input_v = | ||||
|       reinterpret_cast<_f16Vec<scalar_t, width>*>(input); | ||||
|   auto* __restrict__ residual_v = | ||||
|       reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); | ||||
|   auto* __restrict__ weight_v = | ||||
|       reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); | ||||
|  | ||||
|   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { | ||||
|     int id = blockIdx.x * vec_hidden_size + idx; | ||||
| @ -215,10 +229,11 @@ __global__ std::enable_if_t< | ||||
|     residual_v[id] = temp; | ||||
|   } | ||||
|   /* Keep the following if-else block in sync with the | ||||
|      calculation of max_block_size in fused_add_rms_norm */  | ||||
|      calculation of max_block_size in fused_add_rms_norm */ | ||||
|   if (num_tokens < 256) { | ||||
|     variance = blockReduceSum<float, 1024>(variance); | ||||
|   } else variance = blockReduceSum<float, 256>(variance); | ||||
|   } else | ||||
|     variance = blockReduceSum<float, 256>(variance); | ||||
|   if (threadIdx.x == 0) { | ||||
|     s_variance = rsqrtf(variance / hidden_size + epsilon); | ||||
|   } | ||||
| @ -233,52 +248,50 @@ __global__ std::enable_if_t< | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| /* Generic fused_add_rms_norm_kernel | ||||
|    The width field is not used here but necessary for other specializations. | ||||
|  */ | ||||
| template<typename scalar_t, int width> | ||||
| __global__ std::enable_if_t< | ||||
|   (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( | ||||
|   scalar_t* __restrict__ input,           // [..., hidden_size] | ||||
|   scalar_t* __restrict__ residual,        // [..., hidden_size] | ||||
|   const scalar_t* __restrict__ weight,    // [hidden_size] | ||||
|   const float epsilon, | ||||
|   const int num_tokens, | ||||
|   const int hidden_size) { | ||||
| template <typename scalar_t, int width> | ||||
| __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists> | ||||
| fused_add_rms_norm_kernel( | ||||
|     scalar_t* __restrict__ input,         // [..., hidden_size] | ||||
|     scalar_t* __restrict__ residual,      // [..., hidden_size] | ||||
|     const scalar_t* __restrict__ weight,  // [hidden_size] | ||||
|     const float epsilon, const int num_tokens, const int hidden_size) { | ||||
|   __shared__ float s_variance; | ||||
|   float variance = 0.0f; | ||||
|  | ||||
|   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { | ||||
|     scalar_t z = input[blockIdx.x * hidden_size + idx]; | ||||
|     z += residual[blockIdx.x * hidden_size + idx]; | ||||
|     float x = (float) z; | ||||
|     float x = (float)z; | ||||
|     variance += x * x; | ||||
|     residual[blockIdx.x * hidden_size + idx] = z; | ||||
|   } | ||||
|   /* Keep the following if-else block in sync with the | ||||
|      calculation of max_block_size in fused_add_rms_norm */  | ||||
|      calculation of max_block_size in fused_add_rms_norm */ | ||||
|   if (num_tokens < 256) { | ||||
|     variance = blockReduceSum<float, 1024>(variance); | ||||
|   } else variance = blockReduceSum<float, 256>(variance); | ||||
|   } else | ||||
|     variance = blockReduceSum<float, 256>(variance); | ||||
|   if (threadIdx.x == 0) { | ||||
|     s_variance = rsqrtf(variance / hidden_size + epsilon); | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
|   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { | ||||
|     float x = (float) residual[blockIdx.x * hidden_size + idx]; | ||||
|     input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; | ||||
|     float x = (float)residual[blockIdx.x * hidden_size + idx]; | ||||
|     input[blockIdx.x * hidden_size + idx] = | ||||
|         ((scalar_t)(x * s_variance)) * weight[idx]; | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| void rms_norm( | ||||
|   torch::Tensor& out,      // [..., hidden_size] | ||||
|   torch::Tensor& input,    // [..., hidden_size] | ||||
|   torch::Tensor& weight,   // [hidden_size] | ||||
|   float epsilon) { | ||||
| void rms_norm(torch::Tensor& out,     // [..., hidden_size] | ||||
|               torch::Tensor& input,   // [..., hidden_size] | ||||
|               torch::Tensor& weight,  // [hidden_size] | ||||
|               float epsilon) { | ||||
|   int hidden_size = input.size(-1); | ||||
|   int num_tokens = input.numel() / hidden_size; | ||||
|  | ||||
| @ -286,40 +299,27 @@ void rms_norm( | ||||
|   dim3 block(std::min(hidden_size, 1024)); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|     input.scalar_type(), | ||||
|     "rms_norm_kernel", | ||||
|     [&] { | ||||
|       vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|         out.data_ptr<scalar_t>(), | ||||
|         input.data_ptr<scalar_t>(), | ||||
|         weight.data_ptr<scalar_t>(), | ||||
|         epsilon, | ||||
|         num_tokens, | ||||
|         hidden_size); | ||||
|     }); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { | ||||
|     vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), | ||||
|         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| #define LAUNCH_FUSED_ADD_RMS_NORM(width)              \ | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(                       \ | ||||
|     input.scalar_type(),                              \ | ||||
|     "fused_add_rms_norm_kernel",                      \ | ||||
|     [&] {                                             \ | ||||
|       vllm::fused_add_rms_norm_kernel                 \ | ||||
|       <scalar_t, width><<<grid, block, 0, stream>>>(  \ | ||||
|         input.data_ptr<scalar_t>(),                   \ | ||||
|         residual.data_ptr<scalar_t>(),                \ | ||||
|         weight.data_ptr<scalar_t>(),                  \ | ||||
|         epsilon,                                      \ | ||||
|         num_tokens,                                   \ | ||||
|         hidden_size);                                 \ | ||||
|     }); | ||||
| #define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \ | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(                                                \ | ||||
|       input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \ | ||||
|         vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \ | ||||
|             <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \ | ||||
|                                          residual.data_ptr<scalar_t>(),        \ | ||||
|                                          weight.data_ptr<scalar_t>(), epsilon, \ | ||||
|                                          num_tokens, hidden_size);             \ | ||||
|       }); | ||||
|  | ||||
| void fused_add_rms_norm( | ||||
|   torch::Tensor& input,    // [..., hidden_size] | ||||
|   torch::Tensor& residual, // [..., hidden_size] | ||||
|   torch::Tensor& weight,   // [hidden_size] | ||||
|   float epsilon) { | ||||
| void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size] | ||||
|                         torch::Tensor& residual,  // [..., hidden_size] | ||||
|                         torch::Tensor& weight,    // [hidden_size] | ||||
|                         float epsilon) { | ||||
|   int hidden_size = input.size(-1); | ||||
|   int num_tokens = input.numel() / hidden_size; | ||||
|  | ||||
| @ -342,8 +342,8 @@ void fused_add_rms_norm( | ||||
|   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); | ||||
|   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); | ||||
|   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); | ||||
|   bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ | ||||
|                           && wt_ptr % 16 == 0; | ||||
|   bool ptrs_are_aligned = | ||||
|       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; | ||||
|   if (ptrs_are_aligned && hidden_size % 8 == 0) { | ||||
|     LAUNCH_FUSED_ADD_RMS_NORM(8); | ||||
|   } else { | ||||
|  | ||||
| @ -3,5 +3,6 @@ | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
|   m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); | ||||
|   m.def("topk_softmax", &topk_softmax, | ||||
|         "Apply topk softmax to the gating outputs."); | ||||
| } | ||||
|  | ||||
| @ -2,8 +2,6 @@ | ||||
|  | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| void topk_softmax( | ||||
|   torch::Tensor& topk_weights, | ||||
|   torch::Tensor& topk_indices, | ||||
|   torch::Tensor& token_expert_indices, | ||||
|   torch::Tensor& gating_output); | ||||
| void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, | ||||
|                   torch::Tensor& token_expert_indices, | ||||
|                   torch::Tensor& gating_output); | ||||
|  | ||||
| @ -19,15 +19,22 @@ | ||||
| #include <torch/extension.h> | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
| #include "../cuda_compat.h" | ||||
|  | ||||
| #include <cub/cub.cuh> | ||||
| #include <cub/util_type.cuh> | ||||
| #ifndef USE_ROCM | ||||
|     #include <cub/util_type.cuh> | ||||
|     #include <cub/cub.cuh> | ||||
| #else | ||||
|     #include <hipcub/util_type.hpp> | ||||
|     #include <hipcub/hipcub.hpp> | ||||
| #endif | ||||
|  | ||||
| #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||||
| #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||||
|  | ||||
| namespace vllm { | ||||
| namespace moe { | ||||
|  | ||||
| static constexpr int WARP_SIZE = 32; | ||||
|  | ||||
| /// Aligned array type | ||||
| template < | ||||
|     typename T, | ||||
| @ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ | ||||
| #pragma unroll | ||||
|     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) | ||||
|     { | ||||
|         thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); | ||||
|         thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); | ||||
|     } | ||||
|  | ||||
|     // From this point, thread max in all the threads have the max within the row. | ||||
| @ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ | ||||
| #pragma unroll | ||||
|     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) | ||||
|     { | ||||
|         row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); | ||||
|         row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); | ||||
|     } | ||||
|  | ||||
|     // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables | ||||
| @ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ | ||||
| #pragma unroll | ||||
|         for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) | ||||
|         { | ||||
|             float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); | ||||
|             int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); | ||||
|             float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); | ||||
|             int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); | ||||
|  | ||||
|             // We want lower indices to "win" in every thread so we break ties this way | ||||
|             if (other_max > max_val || (other_max == max_val && other_expert < expert)) | ||||
| @ -383,7 +390,7 @@ struct TopkConstants | ||||
| { | ||||
|     static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); | ||||
|     static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); | ||||
|     static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); | ||||
|     static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); | ||||
|     static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; | ||||
|     static constexpr int THREADS_PER_ROW = EXPERTS / VPT; | ||||
|     static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; | ||||
| @ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f | ||||
| { | ||||
|     static constexpr std::size_t MAX_BYTES_PER_LDG = 16; | ||||
|  | ||||
|     static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); | ||||
|     static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); | ||||
|     using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>; | ||||
|     static constexpr int VPT = Constants::VPT; | ||||
|     static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; | ||||
|  | ||||
| @ -7,119 +7,128 @@ | ||||
| #include "cuda_compat.h" | ||||
| #include "dispatch_utils.h" | ||||
|  | ||||
| #define CEILDIV(x,y) (((x) + (y) - 1) / (y)) | ||||
| #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| namespace { | ||||
| __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { | ||||
|     // don't worry about overflow because num_experts is relatively small | ||||
|     return row * total_col + col; | ||||
| } | ||||
| __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, | ||||
|                                          int32_t col) { | ||||
|   // don't worry about overflow because num_experts is relatively small | ||||
|   return row * total_col + col; | ||||
| } | ||||
| }  // namespace | ||||
|  | ||||
| template <typename scalar_t> | ||||
| __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,  | ||||
|                                 int32_t *sorted_token_ids,  | ||||
|                                 int32_t *expert_ids,  | ||||
|                                 int32_t *total_tokens_post_pad, | ||||
|                                 int32_t num_experts,  | ||||
|                                 int32_t block_size,  | ||||
|                                 size_t numel) { | ||||
|     const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); | ||||
|     const size_t start_idx = threadIdx.x * tokens_per_thread; | ||||
| __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, | ||||
|                                             int32_t* sorted_token_ids, | ||||
|                                             int32_t* expert_ids, | ||||
|                                             int32_t* total_tokens_post_pad, | ||||
|                                             int32_t num_experts, | ||||
|                                             int32_t block_size, size_t numel) { | ||||
|   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); | ||||
|   const size_t start_idx = threadIdx.x * tokens_per_thread; | ||||
|  | ||||
|     extern __shared__ int32_t shared_mem[]; | ||||
|   extern __shared__ int32_t shared_mem[]; | ||||
|  | ||||
|     int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) | ||||
|     int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) | ||||
|   int32_t* tokens_cnts = | ||||
|       shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts) | ||||
|   int32_t* cumsum = | ||||
|       shared_mem + (num_experts + 1) * | ||||
|                        num_experts;  // 1d tensor with shape (num_experts + 1) | ||||
|  | ||||
|     for (int i = 0; i < num_experts; ++i) { | ||||
|         tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; | ||||
|   for (int i = 0; i < num_experts; ++i) { | ||||
|     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * In the first step we compute token_cnts[thread_index + 1][expert_index], | ||||
|    * which counts how many tokens in the token shard of thread_index are | ||||
|    * assigned to expert expert_index. | ||||
|    */ | ||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|     ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // For each expert we accumulate the token counts from the different threads. | ||||
|   tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; | ||||
|   for (int i = 1; i <= blockDim.x; ++i) { | ||||
|     tokens_cnts[index(num_experts, i, threadIdx.x)] += | ||||
|         tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // We accumulate the token counts of all experts in thread 0. | ||||
|   if (threadIdx.x == 0) { | ||||
|     cumsum[0] = 0; | ||||
|     for (int i = 1; i <= num_experts; ++i) { | ||||
|       cumsum[i] = cumsum[i - 1] + | ||||
|                   CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], | ||||
|                           block_size) * | ||||
|                       block_size; | ||||
|     } | ||||
|     *total_tokens_post_pad = cumsum[num_experts]; | ||||
|   } | ||||
|  | ||||
|     /** | ||||
|     * In the first step we compute token_cnts[thread_index + 1][expert_index], | ||||
|     * which counts how many tokens in the token shard of thread_index are assigned | ||||
|     * to expert expert_index. | ||||
|     */ | ||||
|     for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|         ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];  | ||||
|     } | ||||
|   __syncthreads(); | ||||
|  | ||||
|     __syncthreads(); | ||||
|   /** | ||||
|    * For each expert, each thread processes the tokens of the corresponding | ||||
|    * blocks and stores the corresponding expert_id for each block. | ||||
|    */ | ||||
|   for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; | ||||
|        i += block_size) { | ||||
|     expert_ids[i / block_size] = threadIdx.x; | ||||
|   } | ||||
|  | ||||
|     // For each expert we accumulate the token counts from the different threads. | ||||
|     tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; | ||||
|     for (int i = 1; i <= blockDim.x; ++i) { | ||||
|         tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; | ||||
|     } | ||||
|  | ||||
|     __syncthreads(); | ||||
|      | ||||
|     // We accumulate the token counts of all experts in thread 0. | ||||
|     if (threadIdx.x == 0) { | ||||
|         cumsum[0] = 0; | ||||
|         for (int i = 1; i <= num_experts; ++i) { | ||||
|             cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; | ||||
|         } | ||||
|         *total_tokens_post_pad = cumsum[num_experts]; | ||||
|     } | ||||
|  | ||||
|     __syncthreads(); | ||||
|  | ||||
|     /** | ||||
|     * For each expert, each thread processes the tokens of the corresponding blocks | ||||
|     * and stores the corresponding expert_id for each block. | ||||
|     */ | ||||
|     for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { | ||||
|         expert_ids[i / block_size] = threadIdx.x; | ||||
|     } | ||||
|      | ||||
|     /** | ||||
|     * Each thread processes a token shard, calculating the index of each token after | ||||
|     * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and | ||||
|     * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], | ||||
|     * where * represents a padding value(preset in python). | ||||
|     */ | ||||
|     for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|         int32_t expert_id = topk_ids[i]; | ||||
|         /** The cumsum[expert_id] stores the starting index of the tokens that the | ||||
|         * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] | ||||
|         * stores the indices of the tokens processed by the expert with expert_id within | ||||
|         * the current thread's token shard. | ||||
|         */ | ||||
|         int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; | ||||
|         sorted_token_ids[rank_post_pad] = i; | ||||
|         ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; | ||||
|     } | ||||
| } | ||||
|   /** | ||||
|    * Each thread processes a token shard, calculating the index of each token | ||||
|    * after sorting by expert number. Given the example topk_ids = | ||||
|    * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, | ||||
|    * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a | ||||
|    * padding value(preset in python). | ||||
|    */ | ||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|     int32_t expert_id = topk_ids[i]; | ||||
|     /** The cumsum[expert_id] stores the starting index of the tokens that the | ||||
|      * expert with expert_id needs to process, and | ||||
|      * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens | ||||
|      * processed by the expert with expert_id within the current thread's token | ||||
|      * shard. | ||||
|      */ | ||||
|     int32_t rank_post_pad = | ||||
|         tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + | ||||
|         cumsum[expert_id]; | ||||
|     sorted_token_ids[rank_post_pad] = i; | ||||
|     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; | ||||
|   } | ||||
| } | ||||
| }  // namespace vllm | ||||
|  | ||||
| void moe_align_block_size( | ||||
|     torch::Tensor topk_ids, | ||||
|     int num_experts, | ||||
|     int block_size, | ||||
|     torch::Tensor sorted_token_ids, | ||||
|     torch::Tensor experts_ids, | ||||
|     torch::Tensor num_tokens_post_pad) { | ||||
|     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|     VLLM_DISPATCH_INTEGRAL_TYPES( | ||||
|         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { | ||||
|         // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors | ||||
|         const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); | ||||
| void moe_align_block_size(torch::Tensor topk_ids, int num_experts, | ||||
|                           int block_size, torch::Tensor sorted_token_ids, | ||||
|                           torch::Tensor experts_ids, | ||||
|                           torch::Tensor num_tokens_post_pad) { | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_INTEGRAL_TYPES( | ||||
|       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { | ||||
|         // calc needed amount of shared mem for `tokens_cnts` and `cumsum` | ||||
|         // tensors | ||||
|         const int32_t shared_mem = | ||||
|             ((num_experts + 1) * num_experts + (num_experts + 1)) * | ||||
|             sizeof(int32_t); | ||||
|  | ||||
|         // set dynamic shared mem | ||||
|         auto kernel = vllm::moe_align_block_size_kernel<scalar_t>; | ||||
|         AT_CUDA_CHECK( | ||||
|             VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); | ||||
|         AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( | ||||
|             (void*)kernel, shared_mem)); | ||||
|         kernel<<<1, num_experts, shared_mem, stream>>>( | ||||
|             topk_ids.data_ptr<scalar_t>(), | ||||
|             sorted_token_ids.data_ptr<int32_t>(),  | ||||
|             experts_ids.data_ptr<int32_t>(),  | ||||
|             num_tokens_post_pad.data_ptr<int32_t>(),  | ||||
|             num_experts, | ||||
|             block_size, | ||||
|             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), | ||||
|             experts_ids.data_ptr<int32_t>(), | ||||
|             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||||
|             topk_ids.numel()); | ||||
|     }); | ||||
|       }); | ||||
| } | ||||
|  | ||||
							
								
								
									
										271
									
								
								csrc/ops.h
									
									
									
									
									
								
							
							
						
						
									
										271
									
								
								csrc/ops.h
									
									
									
									
									
								
							| @ -3,204 +3,139 @@ | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| void paged_attention_v1( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& query, | ||||
|   torch::Tensor& key_cache, | ||||
|   torch::Tensor& value_cache, | ||||
|   int num_kv_heads, | ||||
|   float scale, | ||||
|   torch::Tensor& block_tables, | ||||
|   torch::Tensor& seq_lens, | ||||
|   int block_size, | ||||
|   int max_seq_len, | ||||
|   const c10::optional<torch::Tensor>& alibi_slopes, | ||||
|   const std::string& kv_cache_dtype, | ||||
|   float kv_scale); | ||||
|     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, | ||||
|     torch::Tensor& value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, | ||||
|     int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, | ||||
|     const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, | ||||
|     const int blocksparse_local_blocks, const int blocksparse_vert_stride, | ||||
|     const int blocksparse_block_size, const int blocksparse_head_sliding_step); | ||||
|  | ||||
| void paged_attention_v2( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& exp_sums, | ||||
|   torch::Tensor& max_logits, | ||||
|   torch::Tensor& tmp_out, | ||||
|   torch::Tensor& query, | ||||
|   torch::Tensor& key_cache, | ||||
|   torch::Tensor& value_cache, | ||||
|   int num_kv_heads, | ||||
|   float scale, | ||||
|   torch::Tensor& block_tables, | ||||
|   torch::Tensor& seq_lens, | ||||
|   int block_size, | ||||
|   int max_seq_len, | ||||
|   const c10::optional<torch::Tensor>& alibi_slopes, | ||||
|   const std::string& kv_cache_dtype, | ||||
|   float kv_scale); | ||||
|     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, | ||||
|     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, | ||||
|     torch::Tensor& value_cache, int num_kv_heads, float scale, | ||||
|     torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, | ||||
|     int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, | ||||
|     const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, | ||||
|     const int blocksparse_local_blocks, const int blocksparse_vert_stride, | ||||
|     const int blocksparse_block_size, const int blocksparse_head_sliding_step); | ||||
|  | ||||
| void rms_norm( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input, | ||||
|   torch::Tensor& weight, | ||||
|   float epsilon); | ||||
| void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, | ||||
|               float epsilon); | ||||
|  | ||||
| void fused_add_rms_norm( | ||||
|   torch::Tensor& input, | ||||
|   torch::Tensor& residual, | ||||
|   torch::Tensor& weight, | ||||
|   float epsilon); | ||||
| void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, | ||||
|                         torch::Tensor& weight, float epsilon); | ||||
|  | ||||
| void rotary_embedding( | ||||
|   torch::Tensor& positions, | ||||
|   torch::Tensor& query, | ||||
|   torch::Tensor& key, | ||||
|   int head_size, | ||||
|   torch::Tensor& cos_sin_cache, | ||||
|   bool is_neox); | ||||
| void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, | ||||
|                       torch::Tensor& key, int head_size, | ||||
|                       torch::Tensor& cos_sin_cache, bool is_neox); | ||||
|  | ||||
| void batched_rotary_embedding( | ||||
|   torch::Tensor& positions, | ||||
|   torch::Tensor& query, | ||||
|   torch::Tensor& key, | ||||
|   int head_size, | ||||
|   torch::Tensor& cos_sin_cache, | ||||
|   bool is_neox, | ||||
|   int rot_dim, | ||||
|   torch::Tensor& cos_sin_cache_offsets); | ||||
| void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, | ||||
|                               torch::Tensor& key, int head_size, | ||||
|                               torch::Tensor& cos_sin_cache, bool is_neox, | ||||
|                               int rot_dim, | ||||
|                               torch::Tensor& cos_sin_cache_offsets); | ||||
|  | ||||
| void silu_and_mul( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input); | ||||
| void silu_and_mul(torch::Tensor& out, torch::Tensor& input); | ||||
|  | ||||
| void gelu_and_mul( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input); | ||||
| void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); | ||||
|  | ||||
| void gelu_tanh_and_mul( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input); | ||||
| void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); | ||||
|  | ||||
| void gelu_new( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input); | ||||
| void gelu_new(torch::Tensor& out, torch::Tensor& input); | ||||
|  | ||||
| void gelu_fast( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input); | ||||
| void gelu_fast(torch::Tensor& out, torch::Tensor& input); | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
| torch::Tensor aqlm_gemm( | ||||
|   const torch::Tensor& input, | ||||
|   const torch::Tensor& codes, | ||||
|   const torch::Tensor& codebooks, | ||||
|   const torch::Tensor& scales, | ||||
|   const torch::Tensor& codebook_partition_sizes, | ||||
|   const std::optional<torch::Tensor>& bias | ||||
| ); | ||||
| torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, | ||||
|                         const torch::Tensor& codebooks, | ||||
|                         const torch::Tensor& scales, | ||||
|                         const torch::Tensor& codebook_partition_sizes, | ||||
|                         const std::optional<torch::Tensor>& bias); | ||||
|  | ||||
| torch::Tensor aqlm_dequant( | ||||
|   const torch::Tensor& codes, | ||||
|   const torch::Tensor& codebooks, | ||||
|   const torch::Tensor& codebook_partition_sizes | ||||
| ); | ||||
| torch::Tensor aqlm_dequant(const torch::Tensor& codes, | ||||
|                            const torch::Tensor& codebooks, | ||||
|                            const torch::Tensor& codebook_partition_sizes); | ||||
|  | ||||
| torch::Tensor awq_gemm( | ||||
|   torch::Tensor _in_feats, | ||||
|   torch::Tensor _kernel, | ||||
|   torch::Tensor _scaling_factors, | ||||
|   torch::Tensor _zeros, | ||||
|   int split_k_iters); | ||||
| torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, | ||||
|                        torch::Tensor _scaling_factors, torch::Tensor _zeros, | ||||
|                        int split_k_iters); | ||||
|  | ||||
| torch::Tensor awq_dequantize( | ||||
|     torch::Tensor _kernel, | ||||
|     torch::Tensor _scaling_factors, | ||||
|     torch::Tensor _zeros, | ||||
|     int split_k_iters, | ||||
|     int thx, | ||||
|     int thy); | ||||
| torch::Tensor awq_dequantize(torch::Tensor _kernel, | ||||
|                              torch::Tensor _scaling_factors, | ||||
|                              torch::Tensor _zeros, int split_k_iters, int thx, | ||||
|                              int thy); | ||||
|  | ||||
| torch::Tensor marlin_gemm( | ||||
|     torch::Tensor& a,  | ||||
|     torch::Tensor& b_q_weight, | ||||
|     torch::Tensor& b_scales,  | ||||
|     torch::Tensor& workspace, | ||||
|     int64_t size_m,  | ||||
|     int64_t size_n,  | ||||
|     int64_t size_k); | ||||
| torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, | ||||
|                           torch::Tensor& b_scales, torch::Tensor& workspace, | ||||
|                           int64_t size_m, int64_t size_n, int64_t size_k); | ||||
|  | ||||
| torch::Tensor gptq_marlin_gemm( | ||||
|   torch::Tensor &a, | ||||
|   torch::Tensor &b_q_weight, | ||||
|   torch::Tensor &b_scales, | ||||
|   torch::Tensor &g_idx, | ||||
|   torch::Tensor &perm, | ||||
|   torch::Tensor &workspace, | ||||
|   int64_t num_bits, | ||||
|   int64_t size_m, | ||||
|   int64_t size_n, | ||||
|   int64_t size_k, | ||||
|   bool is_k_full); | ||||
| torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, | ||||
|                                   torch::Tensor& b_meta, | ||||
|                                   torch::Tensor& b_scales, | ||||
|                                   torch::Tensor& workspace, int64_t num_bits, | ||||
|                                   int64_t size_m, int64_t size_n, | ||||
|                                   int64_t size_k); | ||||
|  | ||||
| torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, | ||||
|                                torch::Tensor& b_scales, torch::Tensor& g_idx, | ||||
|                                torch::Tensor& perm, torch::Tensor& workspace, | ||||
|                                int64_t num_bits, int64_t size_m, int64_t size_n, | ||||
|                                int64_t size_k, bool is_k_full); | ||||
|  | ||||
| torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, | ||||
|                                  int64_t size_k, int64_t size_n, | ||||
|                                  int64_t num_bits); | ||||
|  | ||||
| int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, | ||||
|                          torch::Tensor const& b, torch::Tensor const& a_scales, | ||||
|                          torch::Tensor const& b_scales); | ||||
|  | ||||
| torch::Tensor gptq_marlin_repack( | ||||
|   torch::Tensor &b_q_weight, | ||||
|   torch::Tensor &perm, | ||||
|   int64_t size_k, | ||||
|   int64_t size_n, | ||||
|   int64_t num_bits); | ||||
| #endif | ||||
|  | ||||
| void squeezellm_gemm( | ||||
|   torch::Tensor vec, | ||||
|   torch::Tensor mat, | ||||
|   torch::Tensor mul, | ||||
|   torch::Tensor lookup_table); | ||||
| void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, | ||||
|                               torch::Tensor const& scale); | ||||
|  | ||||
| torch::Tensor gptq_gemm( | ||||
|   torch::Tensor a, | ||||
|   torch::Tensor b_q_weight, | ||||
|   torch::Tensor b_gptq_qzeros, | ||||
|   torch::Tensor b_gptq_scales, | ||||
|   torch::Tensor b_g_idx, | ||||
|   bool use_exllama, | ||||
|   int bit); | ||||
| void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | ||||
|                      torch::Tensor lookup_table); | ||||
|  | ||||
| void gptq_shuffle( | ||||
|   torch::Tensor q_weight, | ||||
|   torch::Tensor q_perm, | ||||
|   int bit); | ||||
| torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, | ||||
|                         torch::Tensor b_gptq_qzeros, | ||||
|                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, | ||||
|                         bool use_exllama, int bit); | ||||
|  | ||||
| void static_scaled_fp8_quant( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input, | ||||
|   torch::Tensor& scale); | ||||
| void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); | ||||
|  | ||||
| void dynamic_scaled_fp8_quant( | ||||
|   torch::Tensor& out, | ||||
|   torch::Tensor& input, | ||||
|   torch::Tensor& scale); | ||||
| void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, | ||||
|                              torch::Tensor& scale); | ||||
|  | ||||
| void moe_align_block_size( | ||||
|   torch::Tensor topk_ids, | ||||
|   int num_experts, | ||||
|   int block_size, | ||||
|   torch::Tensor sorted_token_ids, | ||||
|   torch::Tensor experts_ids, | ||||
|   torch::Tensor num_tokens_post_pad); | ||||
| void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, | ||||
|                               torch::Tensor& scale); | ||||
|  | ||||
| void moe_align_block_size(torch::Tensor topk_ids, int num_experts, | ||||
|                           int block_size, torch::Tensor sorted_token_ids, | ||||
|                           torch::Tensor experts_ids, | ||||
|                           torch::Tensor num_tokens_post_pad); | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
| using fptr_t = uint64_t; | ||||
| fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, | ||||
|                     const std::vector<std::string> &handles, | ||||
|                     const std::vector<int64_t> &offsets, int rank, | ||||
|                     bool full_nvlink); | ||||
| bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, | ||||
| fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, | ||||
|                       const std::vector<std::string>& handles, | ||||
|                       const std::vector<int64_t>& offsets, int rank, | ||||
|                       bool full_nvlink); | ||||
| void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); | ||||
| void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, | ||||
|                       torch::Tensor &out); | ||||
| bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, | ||||
|                       bool full_nvlink); | ||||
| void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); | ||||
| void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, | ||||
|                       torch::Tensor& out); | ||||
| void dispose(fptr_t _fa); | ||||
| int meta_size(); | ||||
| void register_buffer(fptr_t _fa, torch::Tensor &t, | ||||
|                      const std::vector<std::string> &handles, | ||||
|                      const std::vector<int64_t> &offsets); | ||||
| std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); | ||||
| void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, | ||||
|                             const std::vector<std::vector<int64_t>> &offsets); | ||||
| void register_buffer(fptr_t _fa, torch::Tensor& t, | ||||
|                      const std::vector<std::string>& handles, | ||||
|                      const std::vector<int64_t>& offsets); | ||||
| std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( | ||||
|     fptr_t _fa); | ||||
| void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles, | ||||
|                             const std::vector<std::vector<int64_t>>& offsets); | ||||
| #endif | ||||
|  | ||||
| @ -7,14 +7,10 @@ | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| template<typename scalar_t, bool IS_NEOX> | ||||
| template <typename scalar_t, bool IS_NEOX> | ||||
| inline __device__ void apply_token_rotary_embedding( | ||||
|   scalar_t* __restrict__ arr, | ||||
|   const scalar_t* __restrict__ cos_ptr, | ||||
|   const scalar_t* __restrict__ sin_ptr, | ||||
|   int rot_offset, | ||||
|   int embed_dim) | ||||
| { | ||||
|     scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, | ||||
|     const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { | ||||
|   int x_index, y_index; | ||||
|   scalar_t cos, sin; | ||||
|   if (IS_NEOX) { | ||||
| @ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( | ||||
|   arr[y_index] = y * cos + x * sin; | ||||
| } | ||||
|  | ||||
| template<typename scalar_t, bool IS_NEOX> | ||||
| template <typename scalar_t, bool IS_NEOX> | ||||
| inline __device__ void apply_rotary_embedding( | ||||
|   scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] | ||||
|   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] | ||||
|   const scalar_t* cache_ptr, | ||||
|   const int head_size, | ||||
|   const int num_heads, | ||||
|   const int num_kv_heads, | ||||
|   const int rot_dim, | ||||
|   const int token_idx, | ||||
|   const int64_t query_stride, | ||||
|   const int64_t key_stride) | ||||
| { | ||||
|     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads, | ||||
|                                    // head_size] or [num_tokens, num_heads, | ||||
|                                    // head_size] | ||||
|     scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads, | ||||
|                                    // head_size] or [num_tokens, num_kv_heads, | ||||
|                                    // head_size] | ||||
|     const scalar_t* cache_ptr, const int head_size, const int num_heads, | ||||
|     const int num_kv_heads, const int rot_dim, const int token_idx, | ||||
|     const int64_t query_stride, const int64_t key_stride) { | ||||
|   const int embed_dim = rot_dim / 2; | ||||
|   const scalar_t* cos_ptr = cache_ptr; | ||||
|   const scalar_t* sin_ptr = cache_ptr + embed_dim; | ||||
| @ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( | ||||
|     const int head_idx = i / embed_dim; | ||||
|     const int64_t token_head = token_idx * query_stride + head_idx * head_size; | ||||
|     const int rot_offset = i % embed_dim; | ||||
|     apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, | ||||
|                                               sin_ptr, rot_offset, embed_dim); | ||||
|     apply_token_rotary_embedding<scalar_t, IS_NEOX>( | ||||
|         query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); | ||||
|   } | ||||
|  | ||||
|   const int nk = num_kv_heads * embed_dim; | ||||
| @ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( | ||||
|     const int head_idx = i / embed_dim; | ||||
|     const int64_t token_head = token_idx * key_stride + head_idx * head_size; | ||||
|     const int rot_offset = i % embed_dim; | ||||
|     apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, | ||||
|                                               sin_ptr, rot_offset, embed_dim); | ||||
|     apply_token_rotary_embedding<scalar_t, IS_NEOX>( | ||||
|         key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template<typename scalar_t, bool IS_NEOX> | ||||
| template <typename scalar_t, bool IS_NEOX> | ||||
| __global__ void rotary_embedding_kernel( | ||||
|   const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens] | ||||
|   scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] | ||||
|   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] | ||||
|   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2] | ||||
|   const int rot_dim, | ||||
|   const int64_t query_stride, | ||||
|   const int64_t key_stride, | ||||
|   const int num_heads, | ||||
|   const int num_kv_heads, | ||||
|   const int head_size) { | ||||
|     const int64_t* __restrict__ positions,  // [batch_size, seq_len] or | ||||
|                                             // [num_tokens] | ||||
|     scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads, | ||||
|                                    // head_size] or [num_tokens, num_heads, | ||||
|                                    // head_size] | ||||
|     scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads, | ||||
|                                  // head_size] or [num_tokens, num_kv_heads, | ||||
|                                  // head_size] | ||||
|     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // | ||||
|                                                  // 2] | ||||
|     const int rot_dim, const int64_t query_stride, const int64_t key_stride, | ||||
|     const int num_heads, const int num_kv_heads, const int head_size) { | ||||
|   // Each thread block is responsible for one token. | ||||
|   const int token_idx = blockIdx.x; | ||||
|   int64_t pos = positions[token_idx]; | ||||
|   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; | ||||
|  | ||||
|   apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); | ||||
|   apply_rotary_embedding<scalar_t, IS_NEOX>( | ||||
|       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, | ||||
|       token_idx, query_stride, key_stride); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t, bool IS_NEOX> | ||||
| template <typename scalar_t, bool IS_NEOX> | ||||
| __global__ void batched_rotary_embedding_kernel( | ||||
|   const int64_t* __restrict__ positions,              // [batch_size, seq_len] or [num_tokens] | ||||
|   scalar_t* __restrict__ query,                       // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] | ||||
|   scalar_t* __restrict__ key,                         // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] | ||||
|   const scalar_t* __restrict__ cos_sin_cache,         // [max_position, 2, rot_dim // 2] | ||||
|   const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens] | ||||
|   const int rot_dim, | ||||
|   const int64_t query_stride, | ||||
|   const int64_t key_stride, | ||||
|   const int num_heads, | ||||
|   const int num_kv_heads, | ||||
|   const int head_size) { | ||||
|     const int64_t* __restrict__ positions,  // [batch_size, seq_len] or | ||||
|                                             // [num_tokens] | ||||
|     scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads, | ||||
|                                    // head_size] or [num_tokens, num_heads, | ||||
|                                    // head_size] | ||||
|     scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads, | ||||
|                                  // head_size] or [num_tokens, num_kv_heads, | ||||
|                                  // head_size] | ||||
|     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // | ||||
|                                                  // 2] | ||||
|     const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] | ||||
|                                                         // or [num_tokens] | ||||
|     const int rot_dim, const int64_t query_stride, const int64_t key_stride, | ||||
|     const int num_heads, const int num_kv_heads, const int head_size) { | ||||
|   // Each thread block is responsible for one token. | ||||
|   const int token_idx = blockIdx.x; | ||||
|   int64_t pos = positions[token_idx]; | ||||
|   int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; | ||||
|   const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; | ||||
|   const scalar_t* cache_ptr = | ||||
|       cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; | ||||
|  | ||||
|   apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); | ||||
|   apply_rotary_embedding<scalar_t, IS_NEOX>( | ||||
|       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, | ||||
|       token_idx, query_stride, key_stride); | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
| }  // namespace vllm | ||||
|  | ||||
| void rotary_embedding( | ||||
|   torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens] | ||||
|   torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] | ||||
|   torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] | ||||
|   int head_size, | ||||
|   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim] | ||||
|   bool is_neox) { | ||||
|     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens] | ||||
|     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or | ||||
|                            // [num_tokens, num_heads * head_size] | ||||
|     torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or | ||||
|                            // [num_tokens, num_kv_heads * head_size] | ||||
|     int head_size, | ||||
|     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim] | ||||
|     bool is_neox) { | ||||
|   int64_t num_tokens = query.numel() / query.size(-1); | ||||
|   int rot_dim = cos_sin_cache.size(1); | ||||
|   int num_heads = query.size(-1) / head_size; | ||||
| @ -135,36 +141,21 @@ void rotary_embedding( | ||||
|   dim3 block(std::min(num_heads * rot_dim / 2, 512)); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|     query.scalar_type(), | ||||
|     "rotary_embedding", | ||||
|     [&] { | ||||
|       if (is_neox) { | ||||
|         vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( | ||||
|           positions.data_ptr<int64_t>(), | ||||
|           query.data_ptr<scalar_t>(), | ||||
|           key.data_ptr<scalar_t>(), | ||||
|           cos_sin_cache.data_ptr<scalar_t>(), | ||||
|           rot_dim, | ||||
|           query_stride, | ||||
|           key_stride, | ||||
|           num_heads, | ||||
|           num_kv_heads, | ||||
|           head_size); | ||||
|       } else { | ||||
|         vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>( | ||||
|           positions.data_ptr<int64_t>(), | ||||
|           query.data_ptr<scalar_t>(), | ||||
|           key.data_ptr<scalar_t>(), | ||||
|           cos_sin_cache.data_ptr<scalar_t>(), | ||||
|           rot_dim, | ||||
|           query_stride, | ||||
|           key_stride, | ||||
|           num_heads, | ||||
|           num_kv_heads, | ||||
|           head_size); | ||||
|       } | ||||
|     }); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { | ||||
|     if (is_neox) { | ||||
|       vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( | ||||
|           positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), | ||||
|           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim, | ||||
|           query_stride, key_stride, num_heads, num_kv_heads, head_size); | ||||
|     } else { | ||||
|       vllm::rotary_embedding_kernel<scalar_t, false> | ||||
|           <<<grid, block, 0, stream>>>( | ||||
|               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), | ||||
|               key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), | ||||
|               rot_dim, query_stride, key_stride, num_heads, num_kv_heads, | ||||
|               head_size); | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| /* | ||||
| @ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together | ||||
| and process in batched manner. | ||||
| */ | ||||
| void batched_rotary_embedding( | ||||
|   torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens] | ||||
|   torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] | ||||
|   torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] | ||||
|   int head_size, | ||||
|   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim] | ||||
|   bool is_neox, | ||||
|   int rot_dim, | ||||
|   torch::Tensor& cos_sin_cache_offsets // [num_tokens] | ||||
|     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens] | ||||
|     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or | ||||
|                            // [num_tokens, num_heads * head_size] | ||||
|     torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or | ||||
|                            // [num_tokens, num_kv_heads * head_size] | ||||
|     int head_size, | ||||
|     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim] | ||||
|     bool is_neox, int rot_dim, | ||||
|     torch::Tensor& cos_sin_cache_offsets  // [num_tokens] | ||||
| ) { | ||||
|   int64_t num_tokens = cos_sin_cache_offsets.size(0); | ||||
|   int num_heads = query.size(-1) / head_size; | ||||
| @ -191,36 +183,21 @@ void batched_rotary_embedding( | ||||
|   dim3 block(std::min(num_heads * rot_dim / 2, 512)); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|     query.scalar_type(), | ||||
|     "rotary_embedding", | ||||
|     [&] { | ||||
|       if (is_neox) { | ||||
|         vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( | ||||
|           positions.data_ptr<int64_t>(), | ||||
|           query.data_ptr<scalar_t>(), | ||||
|           key.data_ptr<scalar_t>(), | ||||
|           cos_sin_cache.data_ptr<scalar_t>(), | ||||
|           cos_sin_cache_offsets.data_ptr<int64_t>(), | ||||
|           rot_dim, | ||||
|           query_stride, | ||||
|           key_stride, | ||||
|           num_heads, | ||||
|           num_kv_heads, | ||||
|           head_size); | ||||
|       } else { | ||||
|         vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>( | ||||
|           positions.data_ptr<int64_t>(), | ||||
|           query.data_ptr<scalar_t>(), | ||||
|           key.data_ptr<scalar_t>(), | ||||
|           cos_sin_cache.data_ptr<scalar_t>(), | ||||
|           cos_sin_cache_offsets.data_ptr<int64_t>(), | ||||
|           rot_dim, | ||||
|           query_stride, | ||||
|           key_stride, | ||||
|           num_heads, | ||||
|           num_kv_heads, | ||||
|           head_size); | ||||
|       } | ||||
|     }); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { | ||||
|     if (is_neox) { | ||||
|       vllm::batched_rotary_embedding_kernel<scalar_t, true> | ||||
|           <<<grid, block, 0, stream>>>( | ||||
|               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), | ||||
|               key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), | ||||
|               cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride, | ||||
|               key_stride, num_heads, num_kv_heads, head_size); | ||||
|     } else { | ||||
|       vllm::batched_rotary_embedding_kernel<scalar_t, false> | ||||
|           <<<grid, block, 0, stream>>>( | ||||
|               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), | ||||
|               key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), | ||||
|               cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride, | ||||
|               key_stride, num_heads, num_kv_heads, head_size); | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| @ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|     f(in_T, out_T, W_T, narrow, 2752) \ | ||||
|     f(in_T, out_T, W_T, narrow, 2816) \ | ||||
|     f(in_T, out_T, W_T, narrow, 3072) \ | ||||
|     f(in_T, out_T, W_T, narrow, 3328) \ | ||||
|     f(in_T, out_T, W_T, narrow, 3456) \ | ||||
|     f(in_T, out_T, W_T, narrow, 3584) \ | ||||
|     f(in_T, out_T, W_T, narrow, 4096) \ | ||||
| @ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|     f(in_T, out_T, W_T, narrow, 5504) \ | ||||
|     f(in_T, out_T, W_T, narrow, 5632) \ | ||||
|     f(in_T, out_T, W_T, narrow, 6144) \ | ||||
|     f(in_T, out_T, W_T, narrow, 6400) \ | ||||
|     f(in_T, out_T, W_T, narrow, 6848) \ | ||||
|     f(in_T, out_T, W_T, narrow, 6912) \ | ||||
|     f(in_T, out_T, W_T, narrow, 7168) \ | ||||
| @ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|     f(in_T, out_T, W_T, narrow, 22016) \ | ||||
|     f(in_T, out_T, W_T, narrow, 24576) \ | ||||
|     f(in_T, out_T, W_T, narrow, 27392) \ | ||||
|     f(in_T, out_T, W_T, narrow, 27648) \ | ||||
|     f(in_T, out_T, W_T, narrow, 28672) \ | ||||
|     f(in_T, out_T, W_T, narrow, 32000) \ | ||||
|     f(in_T, out_T, W_T, narrow, 32256) \ | ||||
| @ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|     f(in_T, out_T, W_T, 2752, narrow) \ | ||||
|     f(in_T, out_T, W_T, 2816, narrow) \ | ||||
|     f(in_T, out_T, W_T, 3072, narrow) \ | ||||
|     f(in_T, out_T, W_T, 3328, narrow) \ | ||||
|     f(in_T, out_T, W_T, 3456, narrow) \ | ||||
|     f(in_T, out_T, W_T, 3584, narrow) \ | ||||
|     f(in_T, out_T, W_T, 4096, narrow) \ | ||||
| @ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|     f(in_T, out_T, W_T, 5504, narrow) \ | ||||
|     f(in_T, out_T, W_T, 5632, narrow) \ | ||||
|     f(in_T, out_T, W_T, 6144, narrow) \ | ||||
|     f(in_T, out_T, W_T, 6400, narrow) \ | ||||
|     f(in_T, out_T, W_T, 6848, narrow) \ | ||||
|     f(in_T, out_T, W_T, 6912, narrow) \ | ||||
|     f(in_T, out_T, W_T, 7168, narrow) \ | ||||
| @ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|     f(in_T, out_T, W_T, 22016, narrow) \ | ||||
|     f(in_T, out_T, W_T, 24576, narrow) \ | ||||
|     f(in_T, out_T, W_T, 27392, narrow) \ | ||||
|     f(in_T, out_T, W_T, 27648, narrow) \ | ||||
|     f(in_T, out_T, W_T, 28672, narrow) \ | ||||
|     f(in_T, out_T, W_T, 32000, narrow) \ | ||||
|     f(in_T, out_T, W_T, 32256, narrow) \ | ||||
|  | ||||
| @ -1,8 +1,14 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
| #ifndef USE_ROCM | ||||
| #include <cooperative_groups.h> | ||||
| #else | ||||
| #include <hip/hip_cooperative_groups.h> | ||||
| #endif | ||||
| #ifndef USE_ROCM | ||||
| #include <cuda/pipeline> | ||||
| #endif | ||||
| #include <cuda_runtime.h> | ||||
| #include <iostream> | ||||
| #include <stdio.h> | ||||
| @ -11,6 +17,24 @@ | ||||
|  | ||||
| namespace cg = cooperative_groups; | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| template <size_t len> | ||||
| __host__ __device__ | ||||
| inline void* memcpy_blocking(void *dst, const void *src) { | ||||
|   // Does not handle the case of long datatypes | ||||
|   char *d = reinterpret_cast<char *>(dst); | ||||
|   const char *s = reinterpret_cast<const char *>(src); | ||||
|   size_t i = 0; | ||||
| #pragma unroll | ||||
|   for (i = 0; i < len; ++i) { | ||||
|     d[i] = s[i]; | ||||
|   } | ||||
|   return dst; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|  | ||||
| // nthrs = (32, 4) | ||||
| template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size, | ||||
|           size_t W_copy_size, int tx, int ty, int tz, typename in_T, | ||||
| @ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|   } | ||||
| } | ||||
|  | ||||
| #else | ||||
|  | ||||
| template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size, | ||||
|           size_t W_copy_size, int tx, int ty, int tz, typename in_T, | ||||
|           typename out_T, typename W_T> | ||||
| __global__ void | ||||
| bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|                    const W_T *__restrict__ W, | ||||
|                    const int64_t *__restrict__ indicies, int64_t y_offset, | ||||
|                    int64_t full_y_size, int64_t num_layers, int64_t layer_idx, | ||||
|                    float scale) { | ||||
|   size_t batch_idx = blockIdx.y; | ||||
|   int64_t idx = indicies[batch_idx] * num_layers + layer_idx; | ||||
|   if (idx < 0) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   size_t j = blockIdx.x; | ||||
|   constexpr size_t tile_size = tx * ty * vec_size; | ||||
|   constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; | ||||
|   __shared__ float y_warpwise[ty]; | ||||
|  | ||||
|   float y = 0; | ||||
|   vec_t<in_T, vec_size> x_vec; | ||||
|   vec_t<W_T, vec_size> w_vec; | ||||
|   size_t tile_idx; | ||||
|  | ||||
| #pragma unroll | ||||
|   for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { | ||||
|     if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { | ||||
|       x_vec.load(X + (batch_idx * feat_in) + | ||||
|                      tile_idx * tile_size + | ||||
|                      (threadIdx.y * tx + threadIdx.x) * vec_size); | ||||
|       w_vec.load(W + (idx * feat_out + j) * feat_in + | ||||
|                      tile_idx * tile_size + | ||||
|                      (threadIdx.y * tx + threadIdx.x) * vec_size); | ||||
|     } | ||||
|  | ||||
|     float sum = 0.f; | ||||
| #pragma unroll | ||||
|     for (size_t i = 0; i < vec_size; ++i) { | ||||
|       sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale; | ||||
|     } | ||||
| #pragma unroll | ||||
|     for (size_t offset = tx / 2; offset > 0; offset /= 2) { | ||||
|       sum += VLLM_SHFL_DOWN_SYNC(sum, offset); | ||||
|     } | ||||
|  | ||||
|     __syncthreads(); | ||||
|  | ||||
|     if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { | ||||
|       y += sum; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (threadIdx.x == 0) { | ||||
|     y_warpwise[threadIdx.y] = y; | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
|   float y_write = 0.f; | ||||
| #pragma unroll | ||||
|   for (size_t i = 0; i < ty; ++i) { | ||||
|     y_write += y_warpwise[i]; | ||||
|   } | ||||
|   | ||||
|   // write Y; | ||||
|   if (threadIdx.x == 0 && threadIdx.y == 0) { | ||||
|     size_t y_idx = batch_idx * full_y_size + y_offset + j; | ||||
|     Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
| // nthrs = (2, 16, 4) | ||||
| template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz, | ||||
|           typename in_T, typename out_T, typename W_T> | ||||
| @ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|   float sum = 0.f; | ||||
| #pragma unroll | ||||
|   for (size_t i = 0; i < vec_size; ++i) { | ||||
| #ifndef USE_ROCM | ||||
|     sum += float(w_vec[i]) * float(x_vec[i]) * scale; | ||||
| #else | ||||
|     sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale; | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   cg::thread_block_tile g = cg::tiled_partition<tx>(block); | ||||
| @ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|   sum = g.shfl(sum, 0); | ||||
|  | ||||
|   if (threadIdx.x == 0) { | ||||
| #ifndef USE_ROCM | ||||
|     Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + | ||||
|       threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum); | ||||
| #else | ||||
|     size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + | ||||
|                    threadIdx.z * ty + threadIdx.y; | ||||
|     Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum)); | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
| @ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|                                         scale); | ||||
|     } | ||||
|   } else { | ||||
| #ifndef USE_ROCM | ||||
|     static_assert(feat_in % (vec_size * 32) == 0 || | ||||
|                   feat_in % (vec_size * 16) == 0 || | ||||
|                   feat_in % (vec_size * 8) == 0); | ||||
| @ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||
|                                         full_y_size, num_layers, layer_idx, | ||||
|                                         scale); | ||||
|     } | ||||
| #else | ||||
|     constexpr size_t rocm_warp_size = warpSize; | ||||
|  | ||||
| #define CHECK_INPUT_TILEABLE_BY(vec_size_) \ | ||||
|     feat_in % (rocm_warp_size * vec_size_) == 0 | ||||
|  | ||||
| #define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_)       \ | ||||
|     if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) {                       \ | ||||
|       constexpr size_t vec_size_shrink = vec_size_;                         \ | ||||
|       constexpr int tx = tx_;                                               \ | ||||
|       constexpr int ty = ty_;                                               \ | ||||
|       dim3 nblks(feat_out, batch_size);                                     \ | ||||
|       dim3 nthrs(tx, ty);                                                   \ | ||||
|       bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink,                \ | ||||
|                           vec_size_shrink * sizeof(in_T),                   \ | ||||
|                           vec_size_shrink * sizeof(W_T),                    \ | ||||
|                           tx, ty, tz>                                       \ | ||||
|           <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,        \ | ||||
|                                         full_y_size, num_layers, layer_idx, \ | ||||
|                                         scale);                             \ | ||||
|     } | ||||
|  | ||||
|     static_assert(CHECK_INPUT_TILEABLE_BY(32) || | ||||
|                   CHECK_INPUT_TILEABLE_BY(16) || | ||||
|                   CHECK_INPUT_TILEABLE_BY( 8) || | ||||
|                   CHECK_INPUT_TILEABLE_BY( 4) || | ||||
|                   CHECK_INPUT_TILEABLE_BY( 2) || | ||||
|                   CHECK_INPUT_TILEABLE_BY( 1)); | ||||
|      | ||||
|     LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) | ||||
|     else | ||||
|     LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) | ||||
|     else | ||||
|     LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size,  8/vec_size) | ||||
|     else | ||||
|     LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) | ||||
|     else | ||||
|     LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) | ||||
|     else | ||||
|     LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) | ||||
|  | ||||
| #undef CHECK_INPUT_TILEABLE_BY | ||||
| #undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM | ||||
| #endif | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -1,8 +1,6 @@ | ||||
| #ifndef VEC_DTYPES_CUH_ | ||||
| #define VEC_DTYPES_CUH_ | ||||
|  | ||||
| #include <cuda_bf16.h> | ||||
| #include <cuda_fp16.h> | ||||
| #ifdef FLASHINFER_USE_FP8 | ||||
| #include <cuda_fp8.h> | ||||
| #endif | ||||
| @ -10,6 +8,9 @@ | ||||
|  | ||||
| #include <type_traits> | ||||
|  | ||||
| #include "../type_convert.h" | ||||
| #include "../../cuda_compat.h" | ||||
|  | ||||
| #define FLASHINFER_INLINE \ | ||||
|   inline __attribute__((always_inline)) __device__ __host__ | ||||
|  | ||||
|  | ||||
| @ -1,12 +1,11 @@ | ||||
| #include <cuda_bf16.h> | ||||
| #include <cuda_fp16.h> | ||||
| #include <torch/extension.h> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
| #include <cstdint> | ||||
| 
 | ||||
| #include "type_convert.h" | ||||
| #include "../cuda_compat.h" | ||||
| #include "bgmv/bgmv_config.h" | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| //====== utils ====== | ||||
| 
 | ||||
| @ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, | ||||
|   TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, | ||||
|               " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); | ||||
| } | ||||
| 
 | ||||
| } // namespace | ||||
| 
 | ||||
| //====== pybind ====== | ||||
| 
 | ||||
| #define DEFINE_pybind(name) m.def(#name, &name, #name); | ||||
| 
 | ||||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
|   m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); | ||||
|   m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, | ||||
|         "dispatch_bgmv_low_level"); | ||||
| } | ||||
							
								
								
									
										11
									
								
								csrc/punica/punica_ops.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								csrc/punica/punica_ops.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, | ||||
|                    torch::Tensor indicies, int64_t layer_idx, float scale); | ||||
|  | ||||
| void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, | ||||
|                              torch::Tensor indicies, int64_t layer_idx, | ||||
|                              float scale, int64_t h_in, int64_t h_out, | ||||
|                              int64_t y_offset); | ||||
							
								
								
									
										13
									
								
								csrc/punica/punica_pybind.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								csrc/punica/punica_pybind.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,13 @@ | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| #include "punica_ops.h" | ||||
|  | ||||
| //====== pybind ====== | ||||
|  | ||||
| #define DEFINE_pybind(name) m.def(#name, &name, #name); | ||||
|  | ||||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
|   m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); | ||||
|   m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, | ||||
|         "dispatch_bgmv_low_level"); | ||||
| } | ||||
							
								
								
									
										82
									
								
								csrc/punica/type_convert.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								csrc/punica/type_convert.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,82 @@ | ||||
| #ifndef CSRC__PUNICA__TYPE_CONVERT_H__ | ||||
| #define CSRC__PUNICA__TYPE_CONVERT_H__ | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|  | ||||
| #include <cuda_bf16.h> | ||||
| #include <cuda_fp16.h> | ||||
|  | ||||
| #else | ||||
|  | ||||
| #include <hip/hip_bf16.h> | ||||
| #include <hip/hip_fp16.h> | ||||
|  | ||||
| #define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ | ||||
|  | ||||
| typedef __half nv_half; | ||||
| typedef __hip_bfloat16 nv_bfloat16; | ||||
| typedef __hip_bfloat162 nv_bfloat162; | ||||
|  | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { | ||||
|   return __hip_bfloat162{val, val}; | ||||
| } | ||||
|  | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { | ||||
|   return __hip_bfloat162{vall, valr}; | ||||
| } | ||||
|  | ||||
| template <typename T_src, typename T_dst> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline T_dst convert_type(T_src val) { | ||||
|   return static_cast<T_dst>(val); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline float convert_type<__half, float>(__half val) { | ||||
|   return __half2float(val); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline __half convert_type<float, __half>(float val) { | ||||
|   return __float2half(val); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { | ||||
|   return __bfloat162float(val); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) { | ||||
|   return __float2bfloat16(val); | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline T vllm_add(T a, T b) { | ||||
|   return a + b; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline __half vllm_add<__half>(__half a, __half b) { | ||||
|   return __hadd(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __TYPE_CONVERT__HOST_DEVICE__ | ||||
| inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { | ||||
|   return __hadd(a, b); | ||||
| } | ||||
|  | ||||
| #undef __TYPE_CONVERT__HOST_DEVICE__ | ||||
|  | ||||
| #endif // USE_ROCM | ||||
|  | ||||
| #endif // CSRC__PUNICA__TYPE_CONVERT_H__ | ||||
							
								
								
									
										143
									
								
								csrc/pybind.cpp
									
									
									
									
									
								
							
							
						
						
									
										143
									
								
								csrc/pybind.cpp
									
									
									
									
									
								
							| @ -8,114 +8,90 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
|   pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); | ||||
|  | ||||
|   // Attention ops | ||||
|   ops.def( | ||||
|     "paged_attention_v1", | ||||
|     &paged_attention_v1, | ||||
|     "Compute the attention between an input query and the cached keys/values using PagedAttention."); | ||||
|   ops.def( | ||||
|     "paged_attention_v2", | ||||
|     &paged_attention_v2, | ||||
|     "PagedAttention V2."); | ||||
|   ops.def("paged_attention_v1", &paged_attention_v1, | ||||
|           "Compute the attention between an input query and the cached " | ||||
|           "keys/values using PagedAttention."); | ||||
|   ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); | ||||
|  | ||||
|   // Activation ops | ||||
|   ops.def( | ||||
|     "silu_and_mul", | ||||
|     &silu_and_mul, | ||||
|     "Activation function used in SwiGLU."); | ||||
|   ops.def( | ||||
|     "gelu_and_mul", | ||||
|     &gelu_and_mul, | ||||
|     "Activation function used in GeGLU with `none` approximation."); | ||||
|   ops.def( | ||||
|     "gelu_tanh_and_mul", | ||||
|     &gelu_tanh_and_mul, | ||||
|     "Activation function used in GeGLU with `tanh` approximation."); | ||||
|   ops.def( | ||||
|     "gelu_new", | ||||
|     &gelu_new, | ||||
|     "GELU implementation used in GPT-2."); | ||||
|   ops.def( | ||||
|     "gelu_fast", | ||||
|     &gelu_fast, | ||||
|     "Approximate GELU implementation."); | ||||
|   ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); | ||||
|   ops.def("gelu_and_mul", &gelu_and_mul, | ||||
|           "Activation function used in GeGLU with `none` approximation."); | ||||
|   ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, | ||||
|           "Activation function used in GeGLU with `tanh` approximation."); | ||||
|   ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); | ||||
|   ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); | ||||
|  | ||||
|   // Layernorm | ||||
|   ops.def( | ||||
|     "rms_norm", | ||||
|     &rms_norm, | ||||
|     "Apply Root Mean Square (RMS) Normalization to the input tensor."); | ||||
|   ops.def("rms_norm", &rms_norm, | ||||
|           "Apply Root Mean Square (RMS) Normalization to the input tensor."); | ||||
|  | ||||
|   ops.def( | ||||
|     "fused_add_rms_norm", | ||||
|     &fused_add_rms_norm, | ||||
|     "In-place fused Add and RMS Normalization"); | ||||
|   ops.def("fused_add_rms_norm", &fused_add_rms_norm, | ||||
|           "In-place fused Add and RMS Normalization"); | ||||
|  | ||||
|   // Rotary embedding | ||||
|   ops.def( | ||||
|     "rotary_embedding", | ||||
|     &rotary_embedding, | ||||
|     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); | ||||
|   ops.def("rotary_embedding", &rotary_embedding, | ||||
|           "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); | ||||
|  | ||||
|   ops.def( | ||||
|     "batched_rotary_embedding", | ||||
|     &batched_rotary_embedding, | ||||
|     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); | ||||
|   ops.def("batched_rotary_embedding", &batched_rotary_embedding, | ||||
|           "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " | ||||
|           "(supports multiple loras)"); | ||||
|  | ||||
| // Quantization ops | ||||
| #ifndef USE_ROCM | ||||
|   ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); | ||||
|   ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); | ||||
|   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); | ||||
|   ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); | ||||
|   ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); | ||||
|   ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); | ||||
|   ops.def("marlin_gemm", &marlin_gemm, | ||||
|           "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); | ||||
|   ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, | ||||
|           "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); | ||||
|   ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, | ||||
|           "gptq_marlin Optimized Quantized GEMM for GPTQ"); | ||||
|   ops.def("gptq_marlin_repack", &gptq_marlin_repack, | ||||
|           "gptq_marlin repack from GPTQ"); | ||||
|   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); | ||||
|   ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, | ||||
|           "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " | ||||
|           "per-row/column quantization."); | ||||
| #endif | ||||
|   | ||||
|  | ||||
|   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); | ||||
|   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); | ||||
|   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); | ||||
|   ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); | ||||
|   ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); | ||||
|   ops.def( | ||||
|     "moe_align_block_size", | ||||
|     &moe_align_block_size, | ||||
|     "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); | ||||
|   ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, | ||||
|           "Compute FP8 quantized tensor for given scaling factor"); | ||||
|   ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, | ||||
|           "Compute FP8 quantized tensor and scaling factor"); | ||||
|   ops.def("moe_align_block_size", &moe_align_block_size, | ||||
|           "Aligning the number of tokens to be processed by each expert such " | ||||
|           "that it is divisible by the block size."); | ||||
|  | ||||
|   ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, | ||||
|           "Compute int8 quantized tensor for given scaling factor"); | ||||
|  | ||||
|   // Cache ops | ||||
|   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); | ||||
|   cache_ops.def( | ||||
|     "swap_blocks", | ||||
|     &swap_blocks, | ||||
|     "Swap in (out) the cache blocks from src to dst"); | ||||
|   cache_ops.def( | ||||
|     "copy_blocks", | ||||
|     ©_blocks, | ||||
|     "Copy the cache blocks from src to dst"); | ||||
|   cache_ops.def( | ||||
|     "reshape_and_cache", | ||||
|     &reshape_and_cache, | ||||
|     "Reshape the key and value tensors and cache them"); | ||||
|   cache_ops.def( | ||||
|     "reshape_and_cache_flash", | ||||
|     &reshape_and_cache_flash, | ||||
|     "Reshape the key and value tensors and cache them"); | ||||
|   cache_ops.def( | ||||
|     "convert_fp8", | ||||
|     &convert_fp8, | ||||
|     "Convert the key and value cache to fp8 data type"); | ||||
|   cache_ops.def("swap_blocks", &swap_blocks, | ||||
|                 "Swap in (out) the cache blocks from src to dst"); | ||||
|   cache_ops.def("copy_blocks", ©_blocks, | ||||
|                 "Copy the cache blocks from src to dst"); | ||||
|   cache_ops.def("reshape_and_cache", &reshape_and_cache, | ||||
|                 "Reshape the key and value tensors and cache them"); | ||||
|   cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, | ||||
|                 "Reshape the key and value tensors and cache them"); | ||||
|   cache_ops.def("convert_fp8", &convert_fp8, | ||||
|                 "Convert the key and value cache to fp8 data type"); | ||||
|  | ||||
|   // Cuda utils | ||||
|   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); | ||||
|   cuda_utils.def( | ||||
|     "get_device_attribute", | ||||
|     &get_device_attribute, | ||||
|     "Gets the specified device attribute."); | ||||
|   pybind11::module cuda_utils = | ||||
|       m.def_submodule("cuda_utils", "vLLM cuda utils"); | ||||
|   cuda_utils.def("get_device_attribute", &get_device_attribute, | ||||
|                  "Gets the specified device attribute."); | ||||
|  | ||||
|   cuda_utils.def( | ||||
|     "get_max_shared_memory_per_block_device_attribute", | ||||
|     &get_max_shared_memory_per_block_device_attribute, | ||||
|     "Gets the maximum shared memory per block device attribute."); | ||||
|   cuda_utils.def("get_max_shared_memory_per_block_device_attribute", | ||||
|                  &get_max_shared_memory_per_block_device_attribute, | ||||
|                  "Gets the maximum shared memory per block device attribute."); | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   // Custom all-reduce kernels | ||||
| @ -132,5 +108,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||
|   custom_ar.def("register_graph_buffers", ®ister_graph_buffers, | ||||
|                 "register_graph_buffers"); | ||||
| #endif | ||||
|  | ||||
| } | ||||
|  | ||||
| @ -25,32 +25,28 @@ | ||||
| #include <iostream> | ||||
| #include <cstdlib> | ||||
|  | ||||
|  | ||||
| namespace vllm { | ||||
| namespace aqlm { | ||||
|  | ||||
| __global__ void Code1x16MatVec( | ||||
|   const int4* __restrict__ A, | ||||
|   const int4* __restrict__ B, | ||||
|         int4* __restrict__ C, | ||||
|   const int4* __restrict__ codebook, | ||||
|   const int prob_m, | ||||
|   const int prob_k, | ||||
|   const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long. | ||||
|   const int codebook_stride // as int4. | ||||
|     const int4* __restrict__ A, const int4* __restrict__ B, | ||||
|     int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, | ||||
|     const int prob_k, | ||||
|     const int4 codebook_a_sizes,  // cumulative sizes of A spanning each | ||||
|                                   // codebook, at most 3 long. | ||||
|     const int codebook_stride     // as int4. | ||||
| ) { | ||||
|   int a_gl_stride = prob_k / 8 / 8; | ||||
|   int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); | ||||
|   bool pred = a_gl_rd < prob_m; | ||||
|  | ||||
|   if (pred) | ||||
|   { | ||||
|     // advance to the correct codebook, this easy because we only multiply one column of the codebook. | ||||
|   if (pred) { | ||||
|     // advance to the correct codebook, this easy because we only multiply one | ||||
|     // column of the codebook. | ||||
|     auto codebook_size = &codebook_a_sizes.x; | ||||
|     while (a_gl_rd >= *codebook_size) | ||||
|     { | ||||
|         codebook += codebook_stride; | ||||
|         ++codebook_size; | ||||
|     while (a_gl_rd >= *codebook_size) { | ||||
|       codebook += codebook_stride; | ||||
|       ++codebook_size; | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -67,8 +63,7 @@ __global__ void Code1x16MatVec( | ||||
|     // We pad shared memory to avoid bank conflicts during reads | ||||
|     __syncthreads(); | ||||
|     for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { | ||||
|       if (b_gl_rd + i < prob_k / 8) | ||||
|         sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; | ||||
|       if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; | ||||
|     } | ||||
|     __syncthreads(); | ||||
|     b_gl_rd += 32 * 8; | ||||
| @ -76,22 +71,19 @@ __global__ void Code1x16MatVec( | ||||
|     int b_sh_rd = 9 * (threadIdx.x % 32); | ||||
|     if (pred && a_gl_rd < a_gl_end) { | ||||
|       const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]); | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < 8; i++) { | ||||
|         uint32_t dec[4]; | ||||
|         // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't | ||||
|         // actually help us; this brings > 2x speedup. | ||||
|         asm volatile ( | ||||
|           "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" | ||||
|           : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) | ||||
|           : "l"((void*) &codebook[enc[i]]) | ||||
|         ); | ||||
|         // We bypass the L1 cache to avoid massive amounts of memory streaming | ||||
|         // that doesn't actually help us; this brings > 2x speedup. | ||||
|         asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" | ||||
|                      : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) | ||||
|                      : "l"((void*)&codebook[enc[i]])); | ||||
|         half2* a = reinterpret_cast<half2*>(&dec); | ||||
|         half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); | ||||
|         half2 res2 = {}; | ||||
|         #pragma unroll | ||||
|         for (int j = 0; j < 4; j++) | ||||
|           res2 = __hfma2(a[j], b[j], res2); | ||||
| #pragma unroll | ||||
|         for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); | ||||
|         res += __half2float(res2.x) + __half2float(res2.y); | ||||
|         b_sh_rd++; | ||||
|       } | ||||
| @ -100,37 +92,33 @@ __global__ void Code1x16MatVec( | ||||
|   } | ||||
|  | ||||
|   if (pred) { | ||||
|     #pragma unroll | ||||
|     for (int i = 16; i > 0; i /= 2) | ||||
|       res += __shfl_down_sync(0xffffffff, res, i); | ||||
| #pragma unroll | ||||
|     for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); | ||||
|     if (threadIdx.x % 32 == 0) | ||||
|       reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); | ||||
|   } | ||||
| } | ||||
|  | ||||
| __global__ void Code2x8MatVec( | ||||
|   const int4* __restrict__ A, | ||||
|   const int4* __restrict__ B, | ||||
|         int4* __restrict__ C, | ||||
|   const int4* __restrict__ codebook, | ||||
|   int prob_m, | ||||
|   int prob_k, | ||||
|   const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long. | ||||
|   const int codebook_stride // as int4. | ||||
|     const int4* __restrict__ A, const int4* __restrict__ B, | ||||
|     int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, | ||||
|     int prob_k, | ||||
|     const int4 codebook_a_sizes,  // cumulative sizes of A spanning each | ||||
|                                   // codebook, at most 3 long. | ||||
|     const int codebook_stride     // as int4. | ||||
|  | ||||
| ) { | ||||
|   int a_gl_stride = prob_k / 8 / 8; | ||||
|   int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); | ||||
|   bool pred = a_gl_rd < prob_m; | ||||
|  | ||||
|   if (pred) | ||||
|   { | ||||
|     // advance to the correct codebook, this easy because we only multiply one column of the codebook. | ||||
|   if (pred) { | ||||
|     // advance to the correct codebook, this easy because we only multiply one | ||||
|     // column of the codebook. | ||||
|     auto codebook_size = &codebook_a_sizes.x; | ||||
|     while (a_gl_rd >= *codebook_size) | ||||
|     { | ||||
|         codebook += codebook_stride; | ||||
|         ++codebook_size; | ||||
|     while (a_gl_rd >= *codebook_size) { | ||||
|       codebook += codebook_stride; | ||||
|       ++codebook_size; | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -148,9 +136,8 @@ __global__ void Code2x8MatVec( | ||||
|  | ||||
|   for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { | ||||
|     int4 dec = codebook[i]; | ||||
|     #pragma unroll | ||||
|     for (int j = 0; j < 8; j++) | ||||
|       sh_code[8 * i + (j + lane) % 8] = dec; | ||||
| #pragma unroll | ||||
|     for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
| @ -161,8 +148,7 @@ __global__ void Code2x8MatVec( | ||||
|     // We pad shared memory to avoid bank conflicts during reads | ||||
|     __syncthreads(); | ||||
|     for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { | ||||
|       if (b_gl_rd + i < prob_k / 8) | ||||
|         sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; | ||||
|       if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; | ||||
|     } | ||||
|     __syncthreads(); | ||||
|     b_gl_rd += 32 * 8; | ||||
| @ -170,13 +156,15 @@ __global__ void Code2x8MatVec( | ||||
|     int b_sh_rd = 9 * (threadIdx.x % 32); | ||||
|     if (pred && a_gl_rd < a_gl_end) { | ||||
|       const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]); | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < 8; i++) { | ||||
|         half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); | ||||
|         half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); | ||||
|         half2*  b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); | ||||
|         half2* a0 = | ||||
|             reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); | ||||
|         half2* a1 = | ||||
|             reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); | ||||
|         half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]); | ||||
|         half2 res2 = {}; | ||||
|         #pragma unroll | ||||
| #pragma unroll | ||||
|         for (int j = 0; j < 4; j++) | ||||
|           res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); | ||||
|         res += __half2float(res2.x) + __half2float(res2.y); | ||||
| @ -187,36 +175,31 @@ __global__ void Code2x8MatVec( | ||||
|   } | ||||
|  | ||||
|   if (pred) { | ||||
|     #pragma unroll | ||||
|     for (int i = 16; i > 0; i /= 2) | ||||
|       res += __shfl_down_sync(0xffffffff, res, i); | ||||
| #pragma unroll | ||||
|     for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); | ||||
|     if (threadIdx.x % 32 == 0) | ||||
|       reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| __global__ void Code1x16Dequant( | ||||
|   const int4* __restrict__ A, | ||||
|         int4* __restrict__ C, | ||||
|   const int4* __restrict__ codebook, | ||||
|   int prob_m, | ||||
|   int prob_k, | ||||
|   const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. | ||||
|   const int codebook_stride // as int4 | ||||
|     const int4* __restrict__ A, int4* __restrict__ C, | ||||
|     const int4* __restrict__ codebook, int prob_m, int prob_k, | ||||
|     const int4 codebook_a_sizes,  // cumulative sizes of A spanning each | ||||
|                                   // codebook, at most 3 long, sums to m. | ||||
|     const int codebook_stride     // as int4 | ||||
| ) { | ||||
|   int a_gl_stride = prob_k / 8 / 8; | ||||
|   int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); | ||||
|   bool pred = a_gl_rd < prob_m; | ||||
|  | ||||
|   if (pred) | ||||
|   { | ||||
|     // advance to the correct codebook, this easy because we only multiply one column of the codebook. | ||||
|   if (pred) { | ||||
|     // advance to the correct codebook, this easy because we only multiply one | ||||
|     // column of the codebook. | ||||
|     auto codebook_size = &codebook_a_sizes.x; | ||||
|     while (a_gl_rd >= *codebook_size) | ||||
|     { | ||||
|         codebook += codebook_stride; | ||||
|         ++codebook_size; | ||||
|     while (a_gl_rd >= *codebook_size) { | ||||
|       codebook += codebook_stride; | ||||
|       ++codebook_size; | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -231,17 +214,15 @@ __global__ void Code1x16Dequant( | ||||
|   while (iters--) { | ||||
|     if (pred && a_gl_rd < a_gl_end) { | ||||
|       const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]); | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < 8; i++) { | ||||
|         int4 chunk; | ||||
|         auto dec = reinterpret_cast<uint32_t*>(&chunk); | ||||
|         // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't | ||||
|         // actually help us; this brings > 2x speedup. | ||||
|         asm volatile ( | ||||
|           "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" | ||||
|           : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) | ||||
|           : "l"((void*) &codebook[enc[i]]) | ||||
|         ); | ||||
|         // We bypass the L1 cache to avoid massive amounts of memory streaming | ||||
|         // that doesn't actually help us; this brings > 2x speedup. | ||||
|         asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" | ||||
|                      : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) | ||||
|                      : "l"((void*)&codebook[enc[i]])); | ||||
|  | ||||
|         C[a_gl_rd * 8 + i] = chunk; | ||||
|       } | ||||
| @ -250,28 +231,25 @@ __global__ void Code1x16Dequant( | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| __global__ void Code2x8Dequant( | ||||
|   const int4* __restrict__ A, | ||||
|         int4* __restrict__ C, | ||||
|   const int4* __restrict__ codebook, | ||||
|   int prob_m, | ||||
|   int prob_k, | ||||
|   const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. | ||||
|   const int codebook_stride // as int4 | ||||
|     const int4* __restrict__ A, int4* __restrict__ C, | ||||
|     const int4* __restrict__ codebook, int prob_m, int prob_k, | ||||
|     const int4 | ||||
|         codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at | ||||
|                            // most 3 long, corresponds to cols. | ||||
|     const int codebook_stride  // as int4 | ||||
| ) { | ||||
|   int a_gl_stride = prob_k / 8 / 8; | ||||
|   int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); | ||||
|   bool pred = a_gl_rd < prob_m; | ||||
|  | ||||
|   if (pred) | ||||
|   { | ||||
|     // advance to the correct codebook, this easy because we only multiply one column of the codebook. | ||||
|   if (pred) { | ||||
|     // advance to the correct codebook, this easy because we only multiply one | ||||
|     // column of the codebook. | ||||
|     auto codebook_size = &codebook_a_sizes.x; | ||||
|     while (a_gl_rd >= *codebook_size) | ||||
|     { | ||||
|         codebook += codebook_stride; | ||||
|         ++codebook_size; | ||||
|     while (a_gl_rd >= *codebook_size) { | ||||
|       codebook += codebook_stride; | ||||
|       ++codebook_size; | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -290,9 +268,8 @@ __global__ void Code2x8Dequant( | ||||
|  | ||||
|   for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { | ||||
|     int4 dec = codebook[i]; | ||||
|     #pragma unroll | ||||
|     for (int j = 0; j < 8; j++) | ||||
|       sh_code[8 * i + (j + lane) % 8] = dec; | ||||
| #pragma unroll | ||||
|     for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
| @ -302,12 +279,14 @@ __global__ void Code2x8Dequant( | ||||
|   while (iters--) { | ||||
|     if (pred && a_gl_rd < a_gl_end) { | ||||
|       const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]); | ||||
|       #pragma unroll | ||||
| #pragma unroll | ||||
|       for (int i = 0; i < 8; i++) { | ||||
|         int4 chunk; | ||||
|         half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); | ||||
|         half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); | ||||
|         #pragma unroll | ||||
|         half2* a0 = | ||||
|             reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]); | ||||
|         half2* a1 = | ||||
|             reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]); | ||||
| #pragma unroll | ||||
|         for (int j = 0; j < 4; j++) | ||||
|           reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]); | ||||
|         C[a_gl_rd * 8 + i] = chunk; | ||||
| @ -317,22 +296,15 @@ __global__ void Code2x8Dequant( | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline int ceildiv(int a, int b) { | ||||
|   return (a + b - 1) / b; | ||||
| } | ||||
| inline int ceildiv(int a, int b) { return (a + b - 1) / b; } | ||||
|  | ||||
| const int THREAD_M = 16; | ||||
|  | ||||
| void  code1x16_matvec_cuda( | ||||
|   const void* __restrict__ A, | ||||
|   const void* __restrict__ B, | ||||
|         void* __restrict__ C, | ||||
|   const void* __restrict__ codebook, | ||||
|   int prob_m, | ||||
|   int prob_k, | ||||
|   const int4 codebook_a_sizes, | ||||
|   const int codebook_stride | ||||
| ) { | ||||
| void code1x16_matvec_cuda(const void* __restrict__ A, | ||||
|                           const void* __restrict__ B, void* __restrict__ C, | ||||
|                           const void* __restrict__ codebook, int prob_m, | ||||
|                           int prob_k, const int4 codebook_a_sizes, | ||||
|                           const int codebook_stride) { | ||||
|   int sms; | ||||
|   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); | ||||
|   int waves = 0; | ||||
| @ -345,28 +317,16 @@ void  code1x16_matvec_cuda( | ||||
|   int blocks = ceildiv(prob_m, thread_m); | ||||
|   int threads = 32 * thread_m; | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|   Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>( | ||||
|     (const int4*) A, | ||||
|     (const int4*) B, | ||||
|     (int4*) C, | ||||
|     (const int4*) codebook, | ||||
|     prob_m, | ||||
|     prob_k, | ||||
|     codebook_a_sizes, | ||||
|     codebook_stride | ||||
|   ); | ||||
|   Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>( | ||||
|       (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, | ||||
|       prob_k, codebook_a_sizes, codebook_stride); | ||||
| } | ||||
|  | ||||
| void  code2x8_matvec_cuda( | ||||
|   const void* __restrict__ A, | ||||
|   const void* __restrict__ B, | ||||
|         void* __restrict__ C, | ||||
|   const void* __restrict__ codebook, | ||||
|   int prob_m, | ||||
|   int prob_k, | ||||
|   const int4 codebook_a_sizes, | ||||
|   const int codebook_stride | ||||
| ) { | ||||
| void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, | ||||
|                          void* __restrict__ C, | ||||
|                          const void* __restrict__ codebook, int prob_m, | ||||
|                          int prob_k, const int4 codebook_a_sizes, | ||||
|                          const int codebook_stride) { | ||||
|   int sms; | ||||
|   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); | ||||
|   int waves = 0; | ||||
| @ -379,30 +339,20 @@ void  code2x8_matvec_cuda( | ||||
|   int blocks = ceildiv(prob_m, thread_m); | ||||
|   int threads = 32 * thread_m; | ||||
|   int shared = 16 * (2 * 256 * 8 + 32 * 9); | ||||
|   cudaFuncSetAttribute( | ||||
|     Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared | ||||
|   ); | ||||
|   cudaFuncSetAttribute(Code2x8MatVec, | ||||
|                        cudaFuncAttributeMaxDynamicSharedMemorySize, shared); | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|   Code2x8MatVec<<<blocks, threads, shared, stream>>>( | ||||
|     (const int4*) A, | ||||
|     (const int4*) B, | ||||
|     (int4*) C, | ||||
|     (const int4*) codebook, | ||||
|     prob_m, | ||||
|     prob_k, | ||||
|     codebook_a_sizes, | ||||
|     codebook_stride | ||||
|   ); | ||||
|       (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, | ||||
|       prob_k, codebook_a_sizes, codebook_stride); | ||||
| } | ||||
|  | ||||
| void code1x16_dequant_cuda( | ||||
|   const void* __restrict__ A, | ||||
|         void* __restrict__ C, | ||||
|   const void* __restrict__ codebook, | ||||
|   int prob_m, | ||||
|   int prob_k, | ||||
|   const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long. | ||||
|   const int codebook_stride // as int4. | ||||
|     const void* __restrict__ A, void* __restrict__ C, | ||||
|     const void* __restrict__ codebook, int prob_m, int prob_k, | ||||
|     const int4 codebook_a_sizes,  // cumulative sizes of A spanning each | ||||
|                                   // codebook, at most 3 long. | ||||
|     const int codebook_stride     // as int4. | ||||
| ) { | ||||
|   int sms; | ||||
|   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); | ||||
| @ -417,25 +367,21 @@ void code1x16_dequant_cuda( | ||||
|   int threads = 32 * thread_m; | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|   Code1x16Dequant<<<blocks, threads, 0, stream>>>( | ||||
|     (const int4*) A, | ||||
|     (int4*) C, | ||||
|     (const int4*) codebook, | ||||
|     prob_m, | ||||
|     prob_k, | ||||
|     codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long. | ||||
|     codebook_stride // as int4. | ||||
|       (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, | ||||
|       codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at | ||||
|                          // most 3 long. | ||||
|       codebook_stride    // as int4. | ||||
|   ); | ||||
| } | ||||
|  | ||||
| // Dequantizes the code and codebook into weights. | ||||
| void  code2x8_dequant_cuda( | ||||
|   const void* __restrict__ A, | ||||
|         void* __restrict__ C, | ||||
|   const void* __restrict__ codebook, | ||||
|   int prob_m, | ||||
|   int prob_k, | ||||
|   const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. | ||||
|   const int codebook_stride // as int4 | ||||
| void code2x8_dequant_cuda( | ||||
|     const void* __restrict__ A, void* __restrict__ C, | ||||
|     const void* __restrict__ codebook, int prob_m, int prob_k, | ||||
|     const int4 | ||||
|         codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at | ||||
|                            // most 3 long, corresponds to cols. | ||||
|     const int codebook_stride  // as int4 | ||||
| ) { | ||||
|   int sms; | ||||
|   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); | ||||
| @ -451,74 +397,50 @@ void  code2x8_dequant_cuda( | ||||
|   int shared = 16 * (2 * 256 * 8 + 32 * 9); | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|  | ||||
|   cudaFuncSetAttribute( | ||||
|     Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared | ||||
|   ); | ||||
|   cudaFuncSetAttribute(Code2x8Dequant, | ||||
|                        cudaFuncAttributeMaxDynamicSharedMemorySize, shared); | ||||
|   Code2x8Dequant<<<blocks, threads, shared, stream>>>( | ||||
|     (const int4*) A, | ||||
|     (int4*) C, | ||||
|     (const int4*) codebook, | ||||
|     prob_m, | ||||
|     prob_k, | ||||
|     codebook_a_sizes, | ||||
|     codebook_stride | ||||
|   ); | ||||
|       (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, | ||||
|       codebook_a_sizes, codebook_stride); | ||||
| } | ||||
|  | ||||
| int codebook_stride(const torch::Tensor& codebooks) | ||||
| { | ||||
| int codebook_stride(const torch::Tensor& codebooks) { | ||||
|   return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); | ||||
| } | ||||
|  | ||||
| void code1x16_matvec( | ||||
|   const torch::Tensor& A, | ||||
|   const torch::Tensor& B, | ||||
|         torch::Tensor& C, | ||||
|   const torch::Tensor& codebook, | ||||
|   const int4 codebook_a_sizes  // cumulative sizes of A spanning each codebook, at most 3 long. | ||||
|     const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, | ||||
|     const torch::Tensor& codebook, | ||||
|     const int4 codebook_a_sizes  // cumulative sizes of A spanning each | ||||
|                                  // codebook, at most 3 long. | ||||
| ) { | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); | ||||
|   int prob_m = C.size(0); | ||||
|   int prob_k = B.size(0); | ||||
|  | ||||
|   code1x16_matvec_cuda( | ||||
|     A.data_ptr(), | ||||
|     B.data_ptr(), | ||||
|     C.data_ptr(), | ||||
|     codebook.data_ptr(), | ||||
|     prob_m, | ||||
|     prob_k, | ||||
|     codebook_a_sizes, | ||||
|     codebook_stride(codebook) | ||||
|   ); | ||||
|   code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), | ||||
|                        codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, | ||||
|                        codebook_stride(codebook)); | ||||
| } | ||||
|  | ||||
| torch::Tensor code1x16_matmat( | ||||
|   const torch::Tensor& input, | ||||
|   const torch::Tensor& codes, | ||||
|   const torch::Tensor& codebooks, | ||||
|   const torch::Tensor& scales, | ||||
|   const int4 codebook_a_sizes, | ||||
|   const std::optional<torch::Tensor>& bias) { | ||||
| torch::Tensor code1x16_matmat(const torch::Tensor& input, | ||||
|                               const torch::Tensor& codes, | ||||
|                               const torch::Tensor& codebooks, | ||||
|                               const torch::Tensor& scales, | ||||
|                               const int4 codebook_a_sizes, | ||||
|                               const std::optional<torch::Tensor>& bias) { | ||||
|   auto input_sizes = input.sizes(); | ||||
|   auto out_features = codes.size(0) * codebooks.size(2); | ||||
|   auto flat_input = input.reshape({-1, input.size(-1)}); | ||||
|   auto flat_output = torch::empty({flat_input.size(0), out_features}, | ||||
|     torch::TensorOptions() | ||||
|       .dtype(input.dtype()) | ||||
|       .device(input.device()) | ||||
|   ); | ||||
|   auto flat_output = torch::empty( | ||||
|       {flat_input.size(0), out_features}, | ||||
|       torch::TensorOptions().dtype(input.dtype()).device(input.device())); | ||||
|  | ||||
|   for (int i = 0; i < flat_input.size(0); ++i) { | ||||
|     auto input_vec = flat_input.index({i}); | ||||
|     auto output_vec = flat_output.index({i}); | ||||
|     code1x16_matvec( | ||||
|       codes.squeeze(2), | ||||
|       input_vec, | ||||
|       output_vec, | ||||
|       codebooks, | ||||
|       codebook_a_sizes | ||||
|     ); | ||||
|     code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, | ||||
|                     codebook_a_sizes); | ||||
|   } | ||||
|   flat_output *= scales.flatten().unsqueeze(0); | ||||
|  | ||||
| @ -533,55 +455,35 @@ torch::Tensor code1x16_matmat( | ||||
|   return output; | ||||
| } | ||||
|  | ||||
| void code2x8_matvec( | ||||
|   const torch::Tensor& A, | ||||
|   const torch::Tensor& B, | ||||
|         torch::Tensor& C, | ||||
|   const torch::Tensor& codebook, | ||||
|   const int4 codebook_a_sizes | ||||
| ) { | ||||
| void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, | ||||
|                     torch::Tensor& C, const torch::Tensor& codebook, | ||||
|                     const int4 codebook_a_sizes) { | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); | ||||
|   int prob_m = C.size(0); | ||||
|   int prob_k = B.size(0); | ||||
|   code2x8_matvec_cuda( | ||||
|     A.data_ptr(), | ||||
|     B.data_ptr(), | ||||
|     C.data_ptr(), | ||||
|     codebook.data_ptr(), | ||||
|     prob_m, | ||||
|     prob_k, | ||||
|     codebook_a_sizes, | ||||
|     2 * codebook_stride(codebook) | ||||
|   ); | ||||
|   code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), | ||||
|                       codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, | ||||
|                       2 * codebook_stride(codebook)); | ||||
| } | ||||
|  | ||||
| torch::Tensor code2x8_matmat( | ||||
|   const torch::Tensor& input, | ||||
|   const torch::Tensor& codes, | ||||
|   const torch::Tensor& codebooks, | ||||
|   const torch::Tensor& scales, | ||||
|   const int4 codebook_a_sizes, | ||||
|   const std::optional<torch::Tensor>& bias | ||||
| ) { | ||||
| torch::Tensor code2x8_matmat(const torch::Tensor& input, | ||||
|                              const torch::Tensor& codes, | ||||
|                              const torch::Tensor& codebooks, | ||||
|                              const torch::Tensor& scales, | ||||
|                              const int4 codebook_a_sizes, | ||||
|                              const std::optional<torch::Tensor>& bias) { | ||||
|   auto input_sizes = input.sizes(); | ||||
|   auto out_features = codes.size(0) * codebooks.size(2); | ||||
|   auto flat_input = input.reshape({-1, input.size(-1)}); | ||||
|   auto flat_output = torch::empty({flat_input.size(0), out_features}, | ||||
|     torch::TensorOptions() | ||||
|       .dtype(input.dtype()) | ||||
|       .device(input.device()) | ||||
|   ); | ||||
|   auto flat_output = torch::empty( | ||||
|       {flat_input.size(0), out_features}, | ||||
|       torch::TensorOptions().dtype(input.dtype()).device(input.device())); | ||||
|  | ||||
|   for (int i = 0; i < flat_input.size(0); ++i) { | ||||
|     auto input_vec = flat_input.index({i}); | ||||
|     auto output_vec = flat_output.index({i}); | ||||
|     code2x8_matvec( | ||||
|       codes.squeeze(2), | ||||
|       input_vec, | ||||
|       output_vec, | ||||
|       codebooks, | ||||
|       codebook_a_sizes | ||||
|     ); | ||||
|     code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, | ||||
|                    codebook_a_sizes); | ||||
|   } | ||||
|   flat_output *= scales.flatten().unsqueeze(0); | ||||
|   if (bias.has_value()) { | ||||
| @ -596,64 +498,56 @@ torch::Tensor code2x8_matmat( | ||||
| } | ||||
|  | ||||
| // Accumulate the partition sizes. | ||||
| int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) | ||||
| { | ||||
| int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { | ||||
|   int4 cumulative_sizes; | ||||
|   auto cumulative_size = &cumulative_sizes.x; | ||||
|   int i = 0; | ||||
|   int last = 0; | ||||
|   assert(codebook_partition_sizes.size(0) <= 4); | ||||
|   for (; i <  codebook_partition_sizes.size(0); ++i, ++cumulative_size) | ||||
|   { | ||||
|   for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { | ||||
|     *cumulative_size = codebook_partition_sizes[i].item<int>() + last; | ||||
|     last = *cumulative_size; | ||||
|   } | ||||
|   // fill in the rest with unreachable. | ||||
|   for (; i < 4; ++i, ++cumulative_size) | ||||
|   { | ||||
|     *cumulative_size = last*10; | ||||
|   for (; i < 4; ++i, ++cumulative_size) { | ||||
|     *cumulative_size = last * 10; | ||||
|   } | ||||
|   return cumulative_sizes; | ||||
| } | ||||
|  | ||||
| } // namespace aqlm | ||||
| } // namespace vllm | ||||
| }  // namespace aqlm | ||||
| }  // namespace vllm | ||||
|  | ||||
|  | ||||
| torch::Tensor aqlm_gemm( | ||||
|   const torch::Tensor& input, | ||||
|   const torch::Tensor& codes, | ||||
|   const torch::Tensor& codebooks, | ||||
|   const torch::Tensor& scales, | ||||
|   const torch::Tensor& codebook_partition_sizes, | ||||
|   const std::optional<torch::Tensor>& bias | ||||
| ) | ||||
| { | ||||
|   int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); | ||||
| torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, | ||||
|                         const torch::Tensor& codebooks, | ||||
|                         const torch::Tensor& scales, | ||||
|                         const torch::Tensor& codebook_partition_sizes, | ||||
|                         const std::optional<torch::Tensor>& bias) { | ||||
|   int4 cumulative_sizes = | ||||
|       vllm::aqlm::accumulate_sizes(codebook_partition_sizes); | ||||
|  | ||||
|   int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); | ||||
|   int const entries = codebooks.size(1); | ||||
|  | ||||
|   if (nbooks == 1 && entries == (1 << 16)) | ||||
|   {  | ||||
|     return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); | ||||
|   if (nbooks == 1 && entries == (1 << 16)) { | ||||
|     return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, | ||||
|                                        cumulative_sizes, bias); | ||||
|   } | ||||
|   if (nbooks == 2 && entries == (1 << 8)) | ||||
|   { | ||||
|     return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); | ||||
|   if (nbooks == 2 && entries == (1 << 8)) { | ||||
|     return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, | ||||
|                                       cumulative_sizes, bias); | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") | ||||
|   TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, | ||||
|               " entries is not currently supported.") | ||||
|   return {}; | ||||
| } | ||||
|  | ||||
| torch::Tensor aqlm_dequant( | ||||
|   const torch::Tensor& codes, | ||||
|   const torch::Tensor& codebooks, | ||||
|   const torch::Tensor& codebook_partition_sizes | ||||
| ) | ||||
| { | ||||
|   int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); | ||||
| torch::Tensor aqlm_dequant(const torch::Tensor& codes, | ||||
|                            const torch::Tensor& codebooks, | ||||
|                            const torch::Tensor& codebook_partition_sizes) { | ||||
|   int4 cumulative_sizes = | ||||
|       vllm::aqlm::accumulate_sizes(codebook_partition_sizes); | ||||
|  | ||||
|   int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); | ||||
|   int const entries = codebooks.size(1); | ||||
| @ -668,45 +562,37 @@ torch::Tensor aqlm_dequant( | ||||
|   assert(out_features = codebook_partition_sizes.sum().item<int>()); | ||||
|  | ||||
|   auto weights = torch::empty({out_features, in_features}, | ||||
|     torch::TensorOptions() | ||||
|       .dtype(codebooks.dtype()) | ||||
|       .device(codebooks.device()) | ||||
|   ); | ||||
|                               torch::TensorOptions() | ||||
|                                   .dtype(codebooks.dtype()) | ||||
|                                   .device(codebooks.device())); | ||||
|  | ||||
|   if (nbooks == 1 && entries == (1 << 16)) | ||||
|   { | ||||
|     vllm::aqlm::code1x16_dequant_cuda( | ||||
|       codes.data_ptr(), | ||||
|       weights.data_ptr(), | ||||
|       codebooks.data_ptr(), | ||||
|       out_features, | ||||
|       in_features, | ||||
|       cumulative_sizes, | ||||
|       vllm::aqlm::codebook_stride(codebooks)); | ||||
|   if (nbooks == 1 && entries == (1 << 16)) { | ||||
|     vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), | ||||
|                                       codebooks.data_ptr(), out_features, | ||||
|                                       in_features, cumulative_sizes, | ||||
|                                       vllm::aqlm::codebook_stride(codebooks)); | ||||
|  | ||||
|     // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) | ||||
|     // weights *= scales.index({"...", 0, 0}); | ||||
|     // if you wanted to flip to scaling the weights, (though it's 30%-ish slower | ||||
|     // and not consistent with gemv implementation.) weights *= | ||||
|     // scales.index({"...", 0, 0}); | ||||
|  | ||||
|      return weights; | ||||
|     return weights; | ||||
|   } | ||||
|  | ||||
|   if (nbooks == 2 && entries == (1 << 8)) | ||||
|   { | ||||
|      vllm::aqlm::code2x8_dequant_cuda( | ||||
|         codes.data_ptr(),  | ||||
|         weights.data_ptr(),  | ||||
|         codebooks.data_ptr(),  | ||||
|         out_features, | ||||
|         in_features,  | ||||
|         cumulative_sizes,  | ||||
|         vllm::aqlm::codebook_stride(codebooks)); | ||||
|   if (nbooks == 2 && entries == (1 << 8)) { | ||||
|     vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), | ||||
|                                      codebooks.data_ptr(), out_features, | ||||
|                                      in_features, cumulative_sizes, | ||||
|                                      vllm::aqlm::codebook_stride(codebooks)); | ||||
|  | ||||
|     // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) | ||||
|     // weights *= scales.index({"...", 0, 0}); | ||||
|     // if you wanted to flip to scaling the weights, (though it's 30%-ish slower | ||||
|     // and not consistent with gemv implementation) weights *= | ||||
|     // scales.index({"...", 0, 0}); | ||||
|  | ||||
|      return weights; | ||||
|     return weights; | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") | ||||
|   TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, | ||||
|               " entries is not currently supported.") | ||||
|   return {}; | ||||
| } | ||||
|  | ||||
| @ -1,11 +1,11 @@ | ||||
| /* | ||||
| Adapted from https://github.com/mit-han-lab/llm-awq | ||||
| Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h | ||||
| Modified from NVIDIA FasterTransformer: | ||||
| https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h | ||||
| @article{lin2023awq, | ||||
|   title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, | ||||
|   author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, | ||||
|   journal={arXiv}, | ||||
|   year={2023} | ||||
|   title={AWQ: Activation-aware Weight Quantization for LLM Compression and | ||||
| Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, | ||||
| Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} | ||||
| } | ||||
| */ | ||||
|  | ||||
| @ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor | ||||
| namespace vllm { | ||||
| namespace awq { | ||||
|  | ||||
| __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) | ||||
| { | ||||
| __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 | ||||
|   assert(false); | ||||
| #else | ||||
|     uint4 result; | ||||
|   uint4 result; | ||||
|  | ||||
|     uint32_t*      h   = reinterpret_cast<uint32_t*>(&result); | ||||
|     uint32_t const i4s = reinterpret_cast<uint32_t const&>(source); | ||||
|   uint32_t* h = reinterpret_cast<uint32_t*>(&result); | ||||
|   uint32_t const i4s = reinterpret_cast<uint32_t const&>(source); | ||||
|  | ||||
|     // First, we extract the i4s and construct an intermediate fp16 number. | ||||
|     static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa; | ||||
|     static constexpr uint32_t BOTTOM_MASK           = 0x000f000f; | ||||
|     static constexpr uint32_t TOP_MASK              = 0x00f000f0; | ||||
|     static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; | ||||
|   // First, we extract the i4s and construct an intermediate fp16 number. | ||||
|   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; | ||||
|   static constexpr uint32_t BOTTOM_MASK = 0x000f000f; | ||||
|   static constexpr uint32_t TOP_MASK = 0x00f000f0; | ||||
|   static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; | ||||
|  | ||||
|     // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing | ||||
|     // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. | ||||
|     // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and | ||||
|     // elt_67 to fp16 without having to shift them to the bottom bits before hand. | ||||
|   // Note that the entire sequence only requires 1 shift instruction. This is | ||||
|   // thanks to the register packing format and the fact that we force our | ||||
|   // integers to be unsigned, and account for this in the fp16 subtractions. In | ||||
|   // addition, I exploit the fact that sub and fma have the same throughput in | ||||
|   // order to convert elt_23 and elt_67 to fp16 without having to shift them to | ||||
|   // the bottom bits before hand. | ||||
|  | ||||
|     // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue | ||||
|     // immediately before required. | ||||
|     const uint32_t top_i4s = i4s >> 8; | ||||
|     // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 | ||||
|     asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                     : "=r"(h[0]) | ||||
|                     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); | ||||
|     // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 | ||||
|     asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                     : "=r"(h[1]) | ||||
|                     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); | ||||
|     // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 | ||||
|     asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                     : "=r"(h[2]) | ||||
|                     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); | ||||
|     // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 | ||||
|     asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                     : "=r"(h[3]) | ||||
|                     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); | ||||
|   // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW | ||||
|   // dependency if we issue immediately before required. | ||||
|   const uint32_t top_i4s = i4s >> 8; | ||||
|   // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 | ||||
|   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                : "=r"(h[0]) | ||||
|                : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), | ||||
|                  "n"(immLut)); | ||||
|   // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 | ||||
|   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                : "=r"(h[1]) | ||||
|                : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), | ||||
|                  "n"(immLut)); | ||||
|   // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 | ||||
|   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                : "=r"(h[2]) | ||||
|                : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), | ||||
|                  "n"(immLut)); | ||||
|   // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 | ||||
|   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" | ||||
|                : "=r"(h[3]) | ||||
|                : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), | ||||
|                  "n"(immLut)); | ||||
|  | ||||
|     // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the | ||||
|     // half2 ctor. In this case, I chose performance reliability over code readability. | ||||
|   // I use inline PTX below because I am not sure if the compiler will emit | ||||
|   // float2half instructions if I use the half2 ctor. In this case, I chose | ||||
|   // performance reliability over code readability. | ||||
|  | ||||
|     // This is the half2 {1032, 1032} represented as an integer. | ||||
|     // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; | ||||
|     // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] | ||||
|     static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; | ||||
|     // This is the half2 {1 / 16, 1 / 16} represented as an integer. | ||||
|     static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; | ||||
|     // This is the half2 {-72, -72} represented as an integer. | ||||
|     // static constexpr uint32_t NEG_72 = 0xd480d480; | ||||
|     // Haotian: Let's use {-64, -64}. | ||||
|     static constexpr uint32_t NEG_64 = 0xd400d400; | ||||
|   // This is the half2 {1032, 1032} represented as an integer. | ||||
|   // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; | ||||
|   // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] | ||||
|   static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; | ||||
|   // This is the half2 {1 / 16, 1 / 16} represented as an integer. | ||||
|   static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; | ||||
|   // This is the half2 {-72, -72} represented as an integer. | ||||
|   // static constexpr uint32_t NEG_72 = 0xd480d480; | ||||
|   // Haotian: Let's use {-64, -64}. | ||||
|   static constexpr uint32_t NEG_64 = 0xd400d400; | ||||
|  | ||||
|     // Finally, we construct the output numbers. | ||||
|     // Convert elt_01 | ||||
|     asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); | ||||
|     // Convert elt_23 | ||||
|     asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); | ||||
|     // Convert elt_45 | ||||
|     asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); | ||||
|     // Convert elt_67 | ||||
|     asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); | ||||
|   // Finally, we construct the output numbers. | ||||
|   // Convert elt_01 | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                : "=r"(h[0]) | ||||
|                : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); | ||||
|   // Convert elt_23 | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                : "=r"(h[1]) | ||||
|                : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); | ||||
|   // Convert elt_45 | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                : "=r"(h[2]) | ||||
|                : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); | ||||
|   // Convert elt_67 | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                : "=r"(h[3]) | ||||
|                : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); | ||||
|  | ||||
|     return result; | ||||
|   return result; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| } // namespace awq | ||||
| } // namespace vllm | ||||
| }  // namespace awq | ||||
| }  // namespace vllm | ||||
|  | ||||
| @ -1,14 +1,12 @@ | ||||
| /* | ||||
| Adapted from https://github.com/mit-han-lab/llm-awq | ||||
| @article{lin2023awq, | ||||
|   title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, | ||||
|   author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, | ||||
|   journal={arXiv}, | ||||
|   year={2023} | ||||
|   title={AWQ: Activation-aware Weight Quantization for LLM Compression and | ||||
| Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, | ||||
| Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} | ||||
| } | ||||
|  */ | ||||
|  | ||||
|  | ||||
| #include <torch/extension.h> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
|  | ||||
| @ -20,26 +18,20 @@ namespace vllm { | ||||
| namespace awq { | ||||
|  | ||||
| // Pack two half values. | ||||
| static inline __device__ __host__ unsigned | ||||
| __pack_half2(const half x, const half y) { | ||||
|   unsigned v0 = *((unsigned short *)&x); | ||||
|   unsigned v1 = *((unsigned short *)&y); | ||||
| static inline __device__ __host__ unsigned __pack_half2(const half x, | ||||
|                                                         const half y) { | ||||
|   unsigned v0 = *((unsigned short*)&x); | ||||
|   unsigned v1 = *((unsigned short*)&y); | ||||
|   return (v1 << 16) | v0; | ||||
| } | ||||
|  | ||||
| template<int N> | ||||
| __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( | ||||
|   int G, | ||||
|   int split_k_iters, | ||||
|   half* __restrict__ A, | ||||
|   int* __restrict__ B, | ||||
|   half* __restrict__ scaling_factors, | ||||
|   int* __restrict__ zeros, | ||||
|   int M, | ||||
|   int IC, | ||||
|   int OC, | ||||
|   half* __restrict__ C) | ||||
| { | ||||
| template <int N> | ||||
| __global__ void __launch_bounds__(64) | ||||
|     gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, | ||||
|                                     half* __restrict__ A, int* __restrict__ B, | ||||
|                                     half* __restrict__ scaling_factors, | ||||
|                                     int* __restrict__ zeros, int M, int IC, | ||||
|                                     int OC, half* __restrict__ C) { | ||||
|   // Only support matrix n = 64 or 128 | ||||
|   assert(N == 64 || N == 128); | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 | ||||
| @ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( | ||||
|   static constexpr int row_stride = 2 * 32 * 8 / N; | ||||
|   bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; | ||||
|   // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 | ||||
|   bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id | ||||
|   bool ld_A_flag = | ||||
|       (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + | ||||
|        threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id | ||||
|   // bool wb_C_flag = (threadIdx.x / 4) < M; | ||||
|  | ||||
|   half* A_ptr = A | ||||
|                 + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC | ||||
|                 + (((int)threadIdx.x) % (32 / 8)) * 8; | ||||
|   half* A_ptr = | ||||
|       A + | ||||
|       (((int)blockIdx_y) / j_factors1 * 16 + | ||||
|        (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * | ||||
|           IC + | ||||
|       (((int)threadIdx.x) % (32 / 8)) * 8; | ||||
|  | ||||
|   int* B_ptr = B | ||||
|             + ((int)threadIdx.y) * (OC / 8) * (256 / N) | ||||
|             + (((int)threadIdx.x) / (N / 8)) * (OC / 8) | ||||
|             + (((int)blockIdx_y) % j_factors1) * (N / 8) | ||||
|             + (((int)threadIdx.x) % (N / 8)) * 1; | ||||
| // Why * 1 in the above line? | ||||
|   int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + | ||||
|                (((int)threadIdx.x) / (N / 8)) * (OC / 8) + | ||||
|                (((int)blockIdx_y) % j_factors1) * (N / 8) + | ||||
|                (((int)threadIdx.x) % (N / 8)) * 1; | ||||
|   // Why * 1 in the above line? | ||||
|  | ||||
|   half* A_shared_ptr = A_shared | ||||
|                     + ((int)threadIdx.y) * row_stride_warp * (32 + 8) | ||||
|                     + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) | ||||
|                     + (((int)threadIdx.x) % (32 / 8) ) * 8; | ||||
|   half* A_shared_ptr = A_shared + | ||||
|                        ((int)threadIdx.y) * row_stride_warp * (32 + 8) + | ||||
|                        (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + | ||||
|                        (((int)threadIdx.x) % (32 / 8)) * 8; | ||||
|  | ||||
|   half* B_shared_ptr = B_shared | ||||
|                     + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) | ||||
|                     + (((int)threadIdx.x) / (N / 8)) * (N + 8) | ||||
|                     + (((int)threadIdx.x) % (N / 8)) * 8; | ||||
|   half* B_shared_ptr = B_shared + | ||||
|                        ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + | ||||
|                        (((int)threadIdx.x) / (N / 8)) * (N + 8) + | ||||
|                        (((int)threadIdx.x) % (N / 8)) * 8; | ||||
|  | ||||
|   int* zeros_ptr = zeros | ||||
|                 + (((int)blockIdx_y) % j_factors1) * (N / 8) | ||||
|                 + ((int)threadIdx.x) % (N / 8); | ||||
|   int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + | ||||
|                    ((int)threadIdx.x) % (N / 8); | ||||
|  | ||||
|   half* scaling_factors_ptr = scaling_factors | ||||
|                             + (((int)blockIdx_y) % j_factors1) * N | ||||
|                             + (((int)threadIdx.x) % (N / 8)) * 8; | ||||
|   half* scaling_factors_ptr = scaling_factors + | ||||
|                               (((int)blockIdx_y) % j_factors1) * N + | ||||
|                               (((int)threadIdx.x) % (N / 8)) * 8; | ||||
|  | ||||
|   half* C_ptr = C | ||||
|               + static_cast<long long>(blockIdx_z) * M * OC        // blockIdz.x -> split_k dim | ||||
|               + (((int)blockIdx_y) % j_factors1) * N | ||||
|               + ((int)threadIdx.y) * (N / 2) | ||||
|               + (((int)threadIdx.x) % 4) * 2; | ||||
|   half* C_ptr = | ||||
|       C + | ||||
|       static_cast<long long>(blockIdx_z) * M * OC  // blockIdz.x -> split_k dim | ||||
|       + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + | ||||
|       (((int)threadIdx.x) % 4) * 2; | ||||
|  | ||||
|   // preload s.f. and zeros | ||||
|   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; | ||||
| @ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( | ||||
|     int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; | ||||
|     __syncthreads(); | ||||
|     // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 | ||||
|     if (ld_A_flag) | ||||
|     { | ||||
|     if (ld_A_flag) { | ||||
|       *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|     } else { | ||||
|       *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); | ||||
|     } | ||||
|  | ||||
|     // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { | ||||
|     uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); | ||||
|     uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); | ||||
|     uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); | ||||
|     uint4 B_loaded_scale = | ||||
|         *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); | ||||
|     /* | ||||
|     if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ | ||||
|       printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); | ||||
|     if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && | ||||
|     threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, | ||||
|     B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, | ||||
|     B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); | ||||
|     } | ||||
|     */ | ||||
|     // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); | ||||
|     int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); | ||||
|  | ||||
|     for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { | ||||
|  | ||||
|       // B: 32 x 136 (128+8) float16 | ||||
|       // each warp: 32 x 4 | ||||
|       // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 | ||||
|       // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); | ||||
|       // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) | ||||
|       uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); | ||||
|       // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus | ||||
|       // zero -> WB UINT4 | ||||
|       // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * | ||||
|       // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) | ||||
|       // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * | ||||
|       // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * | ||||
|       // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * | ||||
|       // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) | ||||
|       uint32_t B_loaded = | ||||
|           *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); | ||||
|       uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); | ||||
|       //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); | ||||
|       // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / | ||||
|       // 8)) * 8); | ||||
|  | ||||
|       // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); | ||||
|       // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x | ||||
|       // % (cta_N / 8)) * 8); | ||||
|       // - zero and * scale | ||||
|       // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); | ||||
|       // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = | ||||
|       // q * scale - zero * scale. | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                    : "=r"(B_loaded_fp16.x) | ||||
|                    : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                    : "=r"(B_loaded_fp16.x) | ||||
|                    : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                    : "=r"(B_loaded_fp16.y) | ||||
|                    : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                    : "=r"(B_loaded_fp16.y) | ||||
|                    : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                    : "=r"(B_loaded_fp16.z) | ||||
|                    : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                    : "=r"(B_loaded_fp16.z) | ||||
|                    : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); | ||||
|       asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                    : "=r"(B_loaded_fp16.w) | ||||
|                    : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); | ||||
|       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                    : "=r"(B_loaded_fp16.w) | ||||
|                    : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); | ||||
|       /* | ||||
|       if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ | ||||
|         printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); | ||||
|       if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == | ||||
|       0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", | ||||
|       B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); | ||||
|       } | ||||
|       */ | ||||
|  | ||||
|       // write back | ||||
|       *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; | ||||
|       *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = | ||||
|           B_loaded_fp16; | ||||
|     } | ||||
|     __syncthreads(); | ||||
|  | ||||
| @ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( | ||||
|       { | ||||
|         unsigned int addr; | ||||
|         __asm__ __volatile__( | ||||
|           "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" | ||||
|           : "=r"(addr) | ||||
|           : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) | ||||
|         ); | ||||
|  | ||||
|             "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " | ||||
|             "addr; }\n" | ||||
|             : "=r"(addr) | ||||
|             : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + | ||||
|                           (((((int)threadIdx.x) & 15) * 40) + | ||||
|                            ((((int)threadIdx.x) >> 4) * 8))))); | ||||
|  | ||||
|         __asm__ __volatile__( | ||||
|           "ldmatrix.sync.aligned.m8n8.x4.shared.b16" | ||||
|           "{%0, %1, %2, %3}, [%4];\n" | ||||
|           : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) | ||||
|           : "r"(addr) | ||||
|         ); | ||||
|             "ldmatrix.sync.aligned.m8n8.x4.shared.b16" | ||||
|             "{%0, %1, %2, %3}, [%4];\n" | ||||
|             : "=r"(((unsigned*)(A_shared_warp + 0))[0]), | ||||
|               "=r"(((unsigned*)(A_shared_warp + 0))[1]), | ||||
|               "=r"(((unsigned*)(A_shared_warp + 0))[2]), | ||||
|               "=r"(((unsigned*)(A_shared_warp + 0))[3]) | ||||
|             : "r"(addr)); | ||||
|       } | ||||
|  | ||||
|       for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { | ||||
|         { | ||||
|           unsigned int addr; | ||||
|           __asm__ __volatile__( | ||||
|             "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" | ||||
|             : "=r"(addr) | ||||
|             : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) | ||||
|           ); | ||||
|               "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " | ||||
|               "addr; }\n" | ||||
|               : "=r"(addr) | ||||
|               : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + | ||||
|                                           (((int)threadIdx.y) * (N / 2))) + | ||||
|                                          (ax1_0 * 16))])) + | ||||
|                             (((((int)threadIdx.x) & 15) * (N + 8)) + | ||||
|                              ((((int)threadIdx.x) >> 4) * 8))))); | ||||
|           __asm__ __volatile__( | ||||
|             "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" | ||||
|             "{%0, %1, %2, %3}, [%4];\n" | ||||
|             : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) | ||||
|             : "r"(addr) | ||||
|           ); | ||||
|               "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" | ||||
|               "{%0, %1, %2, %3}, [%4];\n" | ||||
|               : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), | ||||
|                 "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), | ||||
|                 "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), | ||||
|                 "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) | ||||
|               : "r"(addr)); | ||||
|         } | ||||
|       } | ||||
|       for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 | ||||
|   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 | ||||
|         { | ||||
|           __asm__ __volatile__( | ||||
|             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) | ||||
|             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); | ||||
|               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) | ||||
|               : "r"(((unsigned*)(A_shared_warp + 0))[0]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[1]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); | ||||
|         } | ||||
|  | ||||
|         { | ||||
|           __asm__ __volatile__( | ||||
|             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) | ||||
|             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); | ||||
|               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) | ||||
|               : "r"(((unsigned*)(A_shared_warp + 0))[0]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[1]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); | ||||
|         } | ||||
|  | ||||
|         { | ||||
|           __asm__ __volatile__( | ||||
|             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) | ||||
|             : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); | ||||
|               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) | ||||
|               : "r"(((unsigned*)(A_shared_warp + 0))[2]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[3]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); | ||||
|         } | ||||
|  | ||||
|         { | ||||
|           __asm__ __volatile__( | ||||
|             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) | ||||
|             : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); | ||||
|               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" | ||||
|               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" | ||||
|               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) | ||||
|               : "r"(((unsigned*)(A_shared_warp + 0))[2]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[3]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); | ||||
|         } | ||||
| #else | ||||
|   #else | ||||
|         { | ||||
|           __asm__ __volatile__( | ||||
|             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" | ||||
|             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" | ||||
|             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) | ||||
|             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); | ||||
|               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" | ||||
|               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " | ||||
|               "%13};\n" | ||||
|               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), | ||||
|                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) | ||||
|               : "r"(((unsigned*)(A_shared_warp + 0))[0]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[1]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[2]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[3]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), | ||||
|                 "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); | ||||
|         } | ||||
|  | ||||
|         { | ||||
|           __asm__ __volatile__( | ||||
|             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" | ||||
|             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" | ||||
|             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) | ||||
|             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); | ||||
|               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" | ||||
|               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " | ||||
|               "%13};\n" | ||||
|               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), | ||||
|                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) | ||||
|               : "r"(((unsigned*)(A_shared_warp + 0))[0]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[1]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[2]), | ||||
|                 "r"(((unsigned*)(A_shared_warp + 0))[3]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), | ||||
|                 "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); | ||||
|         } | ||||
|  | ||||
| #endif | ||||
|   #endif | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
| // TODO: Shang: Hoist loop invariance. | ||||
|   // TODO: Shang: Hoist loop invariance. | ||||
|   for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { | ||||
|     for (int local_id = 0; local_id < 8; ++local_id) { | ||||
|       int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; | ||||
|       if (row_offset < M) | ||||
|       { | ||||
|         *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); | ||||
|       int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + | ||||
|                        ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; | ||||
|       if (row_offset < M) { | ||||
|         *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + | ||||
|           local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| __global__ void __launch_bounds__(64) dequantize_weights( | ||||
|     int* __restrict__ B, | ||||
|     half* __restrict__ scaling_factors, | ||||
|     int* __restrict__ zeros, | ||||
|     half* __restrict__ C, | ||||
|     int G | ||||
| ) | ||||
| { | ||||
| __global__ void __launch_bounds__(64) | ||||
|     dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, | ||||
|                        int* __restrict__ zeros, half* __restrict__ C, int G) { | ||||
|   int j_factors1 = 4; | ||||
|   int row_stride2 = 4; | ||||
|   int split_k_iters = 1; | ||||
| @ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( | ||||
|  | ||||
|   uint32_t B_loaded = *(uint32_t*)B_ptr2; | ||||
|   uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                : "=r"(B_loaded_fp16.x) | ||||
|                : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                : "=r"(B_loaded_fp16.x) | ||||
|                : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                : "=r"(B_loaded_fp16.y) | ||||
|                : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                : "=r"(B_loaded_fp16.y) | ||||
|                : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                : "=r"(B_loaded_fp16.z) | ||||
|                : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                : "=r"(B_loaded_fp16.z) | ||||
|                : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); | ||||
|   asm volatile("sub.f16x2 %0, %1, %2;\n" | ||||
|                : "=r"(B_loaded_fp16.w) | ||||
|                : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); | ||||
|   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" | ||||
|                : "=r"(B_loaded_fp16.w) | ||||
|                : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); | ||||
|  | ||||
|   *(uint4*)B_shared_ptr2 = B_loaded_fp16; | ||||
|  | ||||
| @ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights( | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace awq | ||||
| } // namespace vllm | ||||
| }  // namespace awq | ||||
| }  // namespace vllm | ||||
|  | ||||
| torch::Tensor awq_dequantize( | ||||
|     torch::Tensor _kernel, | ||||
|     torch::Tensor _scaling_factors, | ||||
|     torch::Tensor _zeros, | ||||
|     int split_k_iters, | ||||
|     int thx, | ||||
|     int thy) | ||||
| { | ||||
|     int in_c = _kernel.size(0); | ||||
|     int qout_c = _kernel.size(1); | ||||
|     int out_c = qout_c * 8; | ||||
|     int G = in_c / _scaling_factors.size(0); | ||||
| torch::Tensor awq_dequantize(torch::Tensor _kernel, | ||||
|                              torch::Tensor _scaling_factors, | ||||
|                              torch::Tensor _zeros, int split_k_iters, int thx, | ||||
|                              int thy) { | ||||
|   int in_c = _kernel.size(0); | ||||
|   int qout_c = _kernel.size(1); | ||||
|   int out_c = qout_c * 8; | ||||
|   int G = in_c / _scaling_factors.size(0); | ||||
|  | ||||
|     int x_thread = thx; | ||||
|     int y_thread = thy; | ||||
|   int x_thread = thx; | ||||
|   int y_thread = thy; | ||||
|  | ||||
|     int x_blocks = 1; | ||||
|     int y_blocks = 1; | ||||
|     if (thx==0) { | ||||
|       x_thread = qout_c; | ||||
|     } | ||||
|     if (thy==0) { | ||||
|       y_thread = in_c; | ||||
|     } | ||||
|     if (thx==0 && thy==0) { | ||||
|       x_thread = 8; | ||||
|       y_thread = 8; | ||||
|       x_blocks = (int)(qout_c / 8); | ||||
|       y_blocks = (int)(in_c / 8); | ||||
|     } | ||||
|   int x_blocks = 1; | ||||
|   int y_blocks = 1; | ||||
|   if (thx == 0) { | ||||
|     x_thread = qout_c; | ||||
|   } | ||||
|   if (thy == 0) { | ||||
|     y_thread = in_c; | ||||
|   } | ||||
|   if (thx == 0 && thy == 0) { | ||||
|     x_thread = 8; | ||||
|     y_thread = 8; | ||||
|     x_blocks = (int)(qout_c / 8); | ||||
|     y_blocks = (int)(in_c / 8); | ||||
|   } | ||||
|  | ||||
|     const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); | ||||
|  | ||||
|     auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); | ||||
|     at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); | ||||
|   auto options = torch::TensorOptions() | ||||
|                      .dtype(_scaling_factors.dtype()) | ||||
|                      .device(_scaling_factors.device()); | ||||
|   at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); | ||||
|  | ||||
|     auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); | ||||
|     auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>()); | ||||
|     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); | ||||
|     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); | ||||
|   auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); | ||||
|   auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>()); | ||||
|   auto scaling_factors = | ||||
|       reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); | ||||
|   auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); | ||||
|  | ||||
|     dim3 num_blocks(x_blocks, y_blocks); | ||||
|     dim3 threads_per_block(x_thread, y_thread); | ||||
|   dim3 num_blocks(x_blocks, y_blocks); | ||||
|   dim3 threads_per_block(x_thread, y_thread); | ||||
|  | ||||
|     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|     vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>( | ||||
|         kernel, scaling_factors, zeros, de_kernel, G); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>( | ||||
|       kernel, scaling_factors, zeros, de_kernel, G); | ||||
|  | ||||
|     return _de_kernel; | ||||
|   return _de_kernel; | ||||
| } | ||||
|  | ||||
| // in_feats: M, IC [float16] | ||||
| @ -386,61 +489,61 @@ torch::Tensor awq_dequantize( | ||||
| // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] | ||||
| // assume that batch_size < 16 for now | ||||
|  | ||||
| torch::Tensor awq_gemm( | ||||
|     torch::Tensor _in_feats, | ||||
|     torch::Tensor _kernel, | ||||
|     torch::Tensor _scaling_factors, | ||||
|     torch::Tensor _zeros, | ||||
|     int split_k_iters) | ||||
| { | ||||
|     int num_in_feats = _in_feats.size(0); | ||||
|     int num_in_channels = _in_feats.size(1); | ||||
|     const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); | ||||
| torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, | ||||
|                        torch::Tensor _scaling_factors, torch::Tensor _zeros, | ||||
|                        int split_k_iters) { | ||||
|   int num_in_feats = _in_feats.size(0); | ||||
|   int num_in_channels = _in_feats.size(1); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); | ||||
|  | ||||
|     auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); | ||||
|     at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); | ||||
|     int num_out_feats = _out_feats.size(-2); | ||||
|     int num_out_channels = _out_feats.size(-1); | ||||
|   auto options = torch::TensorOptions() | ||||
|                      .dtype(_in_feats.dtype()) | ||||
|                      .device(_in_feats.device()); | ||||
|   at::Tensor _out_feats = | ||||
|       torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); | ||||
|   int num_out_feats = _out_feats.size(-2); | ||||
|   int num_out_channels = _out_feats.size(-1); | ||||
|  | ||||
|     auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>()); | ||||
|     auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); | ||||
|     auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>()); | ||||
|     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); | ||||
|     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); | ||||
|     int group_size = num_in_channels / _scaling_factors.size(0); | ||||
|   auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>()); | ||||
|   auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); | ||||
|   auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>()); | ||||
|   auto scaling_factors = | ||||
|       reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); | ||||
|   auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); | ||||
|   int group_size = num_in_channels / _scaling_factors.size(0); | ||||
|  | ||||
|     if (num_out_channels % 64 != 0) | ||||
|         throw std::invalid_argument("OC is not multiple of cta_N = 64"); | ||||
|     if (num_out_channels % 8 != 0) | ||||
|         throw std::invalid_argument("OC is not multiple of pack_num = 8"); | ||||
|     if (group_size % 32 != 0) | ||||
| 	      throw std::invalid_argument("Group size should be a multiple of 32"); | ||||
|     if (num_out_channels % group_size != 0) | ||||
|         throw std::invalid_argument("OC is not multiple of Group size"); | ||||
|   if (num_out_channels % 64 != 0) | ||||
|     throw std::invalid_argument("OC is not multiple of cta_N = 64"); | ||||
|   if (num_out_channels % 8 != 0) | ||||
|     throw std::invalid_argument("OC is not multiple of pack_num = 8"); | ||||
|   if (group_size % 32 != 0) | ||||
|     throw std::invalid_argument("Group size should be a multiple of 32"); | ||||
|   if (num_out_channels % group_size != 0) | ||||
|     throw std::invalid_argument("OC is not multiple of Group size"); | ||||
|  | ||||
|     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|     if (num_out_channels % 128 == 0) | ||||
|     { | ||||
|         int j_factors1 = num_out_channels / 128 / 1; | ||||
|         dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); | ||||
|         // threadIdx.x: 32 | ||||
|         // threadIdx.y: i_factors[2] * j_factors[2] | ||||
|         dim3 threads_per_block(32, 2); | ||||
|         vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>( | ||||
|             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, | ||||
|             num_out_channels, out_feats); | ||||
|     } | ||||
|     else if (num_out_channels % 64 == 0) | ||||
|     { | ||||
|         int j_factors1 = num_out_channels / 64 / 1; | ||||
|         dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   if (num_out_channels % 128 == 0) { | ||||
|     int j_factors1 = num_out_channels / 128 / 1; | ||||
|     dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); | ||||
|     // threadIdx.x: 32 | ||||
|     // threadIdx.y: i_factors[2] * j_factors[2] | ||||
|     dim3 threads_per_block(32, 2); | ||||
|     vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> | ||||
|         <<<num_blocks, threads_per_block, 0, stream>>>( | ||||
|             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, | ||||
|             num_in_feats, num_in_channels, num_out_channels, out_feats); | ||||
|   } else if (num_out_channels % 64 == 0) { | ||||
|     int j_factors1 = num_out_channels / 64 / 1; | ||||
|     dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * | ||||
|                     split_k_iters); | ||||
|  | ||||
|         // threadIdx.x: 32 | ||||
|         // threadIdx.y: i_factors[2] * j_factors[2] | ||||
|         dim3 threads_per_block(32, 2); | ||||
|         vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>( | ||||
|             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, | ||||
|             num_out_channels, out_feats); | ||||
|     } | ||||
|     return _out_feats.sum(0); | ||||
|     // threadIdx.x: 32 | ||||
|     // threadIdx.y: i_factors[2] * j_factors[2] | ||||
|     dim3 threads_per_block(32, 2); | ||||
|     vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> | ||||
|         <<<num_blocks, threads_per_block, 0, stream>>>( | ||||
|             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, | ||||
|             num_in_feats, num_in_channels, num_out_channels, out_feats); | ||||
|   } | ||||
|   return _out_feats.sum(0); | ||||
| } | ||||
|  | ||||
							
								
								
									
										62
									
								
								csrc/quantization/compressed_tensors/int8_quant_kernels.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								csrc/quantization/compressed_tensors/int8_quant_kernels.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,62 @@ | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
| #include <torch/extension.h> | ||||
| #include <cmath> | ||||
|  | ||||
| #include "../../dispatch_utils.h" | ||||
|  | ||||
| static inline __device__ int8_t float_to_int8_rn(float x) { | ||||
| #ifdef USE_ROCM | ||||
|   static const float i8_min = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::min()); | ||||
|   static const float i8_max = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::max()); | ||||
|   // round | ||||
|   float dst = std::nearbyint(x); | ||||
|   // saturate | ||||
|   dst = std::clamp(dst, i8_min, i8_max); | ||||
|   return static_cast<int8_t>(dst); | ||||
| #else | ||||
|   // CUDA path | ||||
|   uint32_t dst; | ||||
|   asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); | ||||
|   return reinterpret_cast<const int8_t&>(dst); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| template <typename scalar_t, typename scale_type> | ||||
| __global__ void static_scaled_int8_quant_kernel( | ||||
|     const scalar_t* __restrict__ input, int8_t* __restrict__ out, | ||||
|     const scale_type* scale_ptr, const int hidden_size) { | ||||
|   const int tid = threadIdx.x; | ||||
|   const int token_idx = blockIdx.x; | ||||
|   scale_type scale = *scale_ptr; | ||||
|  | ||||
|   for (int i = tid; i < hidden_size; i += blockDim.x) { | ||||
|     out[token_idx * hidden_size + i] = | ||||
|         float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); | ||||
|   } | ||||
| } | ||||
| }  // namespace vllm | ||||
|  | ||||
| void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size] | ||||
|                               torch::Tensor const& input,  // [..., hidden_size] | ||||
|                               torch::Tensor const& scale) { | ||||
|   TORCH_CHECK(input.is_contiguous()); | ||||
|   TORCH_CHECK(out.is_contiguous()); | ||||
|   TORCH_CHECK(scale.numel() == 1); | ||||
|  | ||||
|   int hidden_size = input.size(-1); | ||||
|   int num_tokens = input.numel() / hidden_size; | ||||
|   dim3 grid(num_tokens); | ||||
|   dim3 block(std::min(hidden_size, 1024)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { | ||||
|         vllm::static_scaled_int8_quant_kernel<scalar_t, float> | ||||
|             <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), | ||||
|                                          out.data_ptr<int8_t>(), | ||||
|                                          scale.data_ptr<float>(), hidden_size); | ||||
|       }); | ||||
| } | ||||
							
								
								
									
										346
									
								
								csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										346
									
								
								csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,346 @@ | ||||
| /*************************************************************************************************** | ||||
|  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights | ||||
|  *reserved. SPDX-License-Identifier: BSD-3-Clause | ||||
|  * | ||||
|  * Redistribution and use in source and binary forms, with or without | ||||
|  * modification, are permitted provided that the following conditions are met: | ||||
|  * | ||||
|  * 1. Redistributions of source code must retain the above copyright notice, | ||||
|  *this list of conditions and the following disclaimer. | ||||
|  * | ||||
|  * 2. Redistributions in binary form must reproduce the above copyright notice, | ||||
|  * this list of conditions and the following disclaimer in the documentation | ||||
|  * and/or other materials provided with the distribution. | ||||
|  * | ||||
|  * 3. Neither the name of the copyright holder nor the names of its | ||||
|  * contributors may be used to endorse or promote products derived from | ||||
|  * this software without specific prior written permission. | ||||
|  * | ||||
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||
|  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||||
|  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||||
|  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||||
|  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||||
|  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||||
|  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||||
|  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||||
|  *POSSIBILITY OF SUCH DAMAGE. | ||||
|  * | ||||
|  **************************************************************************************************/ | ||||
|  | ||||
| // | ||||
| // This file is a modified excerpt of | ||||
| // include/cutlass/epilogue/fusion/visitor_load.hpp from | ||||
| // https://github.com/NVIDIA/cutlass v3.5.0 | ||||
| // It has been modified to support either | ||||
| // row/column or scalar broadcasting where the tensor being loaded from is | ||||
| // always passed in via a device pointer. This lets one compiled kernel handle | ||||
| // all cases of per-tensor or per-channel/per-token quantization. | ||||
| // | ||||
| // This interface also allows the scales to be passed in as tensors that | ||||
| // consistently reside on the device, which avoids an issue with a previous | ||||
| // implementation where scalars needed to be on the CPU since they | ||||
| // were passed in via float values. This created a potential performance hazard | ||||
| // if scales were initially on the device, and caused torch.compile graph | ||||
| // breaks when moving scales to the CPU. | ||||
| // | ||||
| #pragma once | ||||
|  | ||||
| // Turn off clang-format for the entire file to keep it close to upstream | ||||
| // clang-format off | ||||
|  | ||||
| #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" | ||||
| #include "cute/tensor.hpp" | ||||
|  | ||||
| namespace cutlass::epilogue::threadblock { | ||||
|  | ||||
| using namespace cute; | ||||
| using namespace detail; | ||||
|  | ||||
| template< | ||||
|   class ThreadMap, | ||||
|   class Element, | ||||
|   class StrideMNL | ||||
| > | ||||
| struct VisitorRowOrScalarBroadcast { | ||||
|  | ||||
|   // This struct has been modified to have a bool indicating that ptr_row is a  | ||||
|   // scalar that must be broadcast. | ||||
|   struct Arguments { | ||||
|     Element const* ptr_row = nullptr; | ||||
|     bool row_broadcast = true; | ||||
|     StrideMNL dRow = {}; | ||||
|   }; | ||||
|  | ||||
|   using Params = Arguments; | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static constexpr Params | ||||
|   to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { | ||||
|     return args; | ||||
|   } | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static size_t | ||||
|   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   struct SharedStorage {}; | ||||
|  | ||||
|   // Global load type | ||||
|   static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value; | ||||
|   using VecType = uint_bit_t<cute::min(128, vec_bits)>; | ||||
|   static int constexpr VecLength = sizeof(VecType) / sizeof(Element); | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   VisitorRowOrScalarBroadcast() { } | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) | ||||
|     : params_ptr(¶ms) { } | ||||
|  | ||||
|   Params const* params_ptr; | ||||
|  | ||||
|   template <class GTensor, class RTensor, class CTensor, class ProblemShape> | ||||
|   struct Callbacks : EmptyCallbacks { | ||||
|     CUTLASS_DEVICE | ||||
|     Callbacks( | ||||
|       GTensor&& tC_gRow, | ||||
|       RTensor&& tC_rRow, | ||||
|       CTensor&& tC_cRow, | ||||
|       ProblemShape problem_shape, | ||||
|       Params const* params_ptr | ||||
|     ): | ||||
|       tC_gRow(cute::forward<GTensor>(tC_gRow)), | ||||
|       tC_rRow(cute::forward<RTensor>(tC_rRow)), | ||||
|       tC_cRow(cute::forward<CTensor>(tC_cRow)), | ||||
|       n(get<1>(problem_shape)), | ||||
|       params_ptr(params_ptr) { } | ||||
|  | ||||
|     GTensor tC_gRow; | ||||
|     RTensor tC_rRow; | ||||
|     CTensor tC_cRow; | ||||
|     Params const* params_ptr; | ||||
|     int n; | ||||
|  | ||||
|     // This function is modified from VisitorRowBroadcast | ||||
|     CUTLASS_DEVICE void | ||||
|     begin_epilogue() { | ||||
|       clear(tC_rRow); | ||||
|       auto src_v = filter(tC_gRow); | ||||
|       auto coord_v = filter(tC_cRow); | ||||
|       auto dst_v = filter(tC_rRow); | ||||
|  | ||||
|       if (params_ptr->row_broadcast) { | ||||
|         // In this case we are loading from a row vector and broadcasting | ||||
|         CUTLASS_PRAGMA_UNROLL | ||||
|         for (int i = 0; i < size(src_v); ++i) { | ||||
|           bool guard = get<1>(coord_v(i)) < n; | ||||
|           cutlass::arch::global_load<VecType, sizeof(VecType)>( | ||||
|               dst_v(i), (void const*)&src_v(i), guard); | ||||
|         } | ||||
|       } else { | ||||
|         // In this case we are loading from a scalar and broadcasting | ||||
|         VecType filled_vec; | ||||
|         CUTLASS_PRAGMA_UNROLL | ||||
|         for (int i = 0; i < VecLength; i++) { | ||||
|           reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row); | ||||
|         } | ||||
|  | ||||
|         CUTLASS_PRAGMA_UNROLL | ||||
|         for (int i = 0; i < size(src_v); ++i) { | ||||
|           if (get<1>(coord_v(i)) < n) { | ||||
|             dst_v(i) = filled_vec; | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     template <class ElementAccumulator, int FragmentSize> | ||||
|     CUTLASS_DEVICE auto // returns an Array | ||||
|     visit(int iter_idx, int row_idx, int column_idx, int frg_idx, | ||||
|           Array<ElementAccumulator, FragmentSize> const& frg_acc) { | ||||
|       Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow)); | ||||
|       return rRow_frg(column_idx); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   CUTLASS_DEVICE auto | ||||
|   get_callbacks( | ||||
|     gemm::GemmCoord threadblock_tile_offset, | ||||
|     int thread_idx, | ||||
|     ProblemShape problem_shape | ||||
|   ) { | ||||
|     Tensor mRow = make_tensor( | ||||
|       make_gmem_ptr(params_ptr->ptr_row), | ||||
|       problem_shape, | ||||
|       params_ptr->dRow); | ||||
|  | ||||
|     // VECTOR, FRAGMENT_COLUMN | ||||
|     Tensor tC_gRow = recast<VecType>( | ||||
|       ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) | ||||
|     )(_,_,_0{},_0{},_0{},_0{}); | ||||
|     Tensor tC_rRow = make_tensor_like(tC_gRow); | ||||
|  | ||||
|     // Generate the pred tensor | ||||
|     Tensor cRow = make_identity_tensor(mRow.shape()); | ||||
|     Tensor tC_cRow = outer_partition( | ||||
|       ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), | ||||
|       Shape<Int<VecLength>>{}, | ||||
|       (_0{}) | ||||
|     ); | ||||
|  | ||||
|     return Callbacks< | ||||
|       decltype(tC_gRow), decltype(tC_rRow), | ||||
|       decltype(tC_cRow), ProblemShape>( | ||||
|       cute::move(tC_gRow), | ||||
|       cute::move(tC_rRow), | ||||
|       cute::move(tC_cRow), | ||||
|       problem_shape, | ||||
|       params_ptr | ||||
|     ); | ||||
|   } | ||||
|  | ||||
| }; | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| // Column vector broadcast | ||||
| template< | ||||
|   class ThreadMap, | ||||
|   class Element, | ||||
|   class StrideMNL = Stride<_1,_0,_0> | ||||
| > | ||||
| struct VisitorColOrScalarBroadcast { | ||||
|  | ||||
|   // This struct has been modified to have a bool indicating that ptr_col is a  | ||||
|   // scalar that must be broadcast. | ||||
|   struct Arguments { | ||||
|     Element const* ptr_col = nullptr; | ||||
|     bool col_broadcast = true; | ||||
|     StrideMNL dCol = {}; | ||||
|   }; | ||||
|  | ||||
|   using Params = Arguments; | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static constexpr Params | ||||
|   to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { | ||||
|     return args; | ||||
|   } | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static size_t | ||||
|   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   struct SharedStorage { }; | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   VisitorColOrScalarBroadcast() { } | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) | ||||
|     : params_ptr(¶ms) { } | ||||
|  | ||||
|   Params const* params_ptr; | ||||
|  | ||||
|   template <class GTensor, class RTensor, class CTensor, class ProblemShape> | ||||
|   struct Callbacks : EmptyCallbacks { | ||||
|     CUTLASS_DEVICE | ||||
|     Callbacks( | ||||
|       GTensor&& tC_gCol, | ||||
|       RTensor&& tC_rCol, | ||||
|       CTensor&& tC_cCol, | ||||
|       ProblemShape problem_shape, | ||||
|       Params const* params_ptr | ||||
|     ): | ||||
|       tC_gCol(cute::forward<GTensor>(tC_gCol)), | ||||
|       tC_rCol(cute::forward<RTensor>(tC_rCol)), | ||||
|       tC_cCol(cute::forward<CTensor>(tC_cCol)), | ||||
|       m(get<0>(problem_shape)), | ||||
|       params_ptr(params_ptr) { } | ||||
|  | ||||
|     GTensor tC_gCol; | ||||
|     RTensor tC_rCol; | ||||
|     CTensor tC_cCol; | ||||
|     Params const* params_ptr; | ||||
|     int m; | ||||
|  | ||||
|     // This function is modified from VisitorColBroadcast | ||||
|     CUTLASS_DEVICE void  | ||||
|     begin_epilogue() { | ||||
|       clear(tC_rCol); | ||||
|  | ||||
|       Tensor pred = make_tensor<bool>(shape(tC_gCol)); | ||||
|       CUTLASS_PRAGMA_UNROLL | ||||
|       for (int i = 0; i < size(pred); ++i) { | ||||
|         pred(i) = get<0>(tC_cCol(i)) < m; | ||||
|       } | ||||
|  | ||||
|       if (params_ptr->col_broadcast) { | ||||
|         // In this case we are loading from a column vector and broadcasting | ||||
|         copy_if(pred, tC_gCol, tC_rCol); | ||||
|       } else { | ||||
|         // In this case we are loading from a scalar and broadcasting | ||||
|         auto dst_v = filter(tC_rCol); | ||||
|  | ||||
|         CUTLASS_PRAGMA_UNROLL | ||||
|         for (int i = 0; i < size(dst_v); ++i) { | ||||
|           if (pred(i)) { | ||||
|             dst_v(i) = *(params_ptr->ptr_col); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     template <class ElementAccumulator, int FragmentSize> | ||||
|     CUTLASS_DEVICE auto // returns an Array | ||||
|     visit(int iter_idx, int row_idx, int column_idx, int frg_idx, | ||||
|           Array<ElementAccumulator, FragmentSize> const& frg_acc) { | ||||
|       Array<Element, FragmentSize> frg_col; | ||||
|       frg_col.fill(tC_rCol(row_idx,iter_idx)); | ||||
|       return frg_col; | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   CUTLASS_DEVICE auto | ||||
|   get_callbacks( | ||||
|     gemm::GemmCoord threadblock_tile_offset, | ||||
|     int thread_idx, | ||||
|     ProblemShape problem_shape | ||||
|   ) { | ||||
|     Tensor mCol = make_tensor( | ||||
|       make_gmem_ptr(params_ptr->ptr_col), | ||||
|       problem_shape, | ||||
|       params_ptr->dCol); | ||||
|  | ||||
|     // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER | ||||
|     Tensor tC_gCol = group_modes<1,4>( | ||||
|       ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); | ||||
|     Tensor tC_rCol = make_tensor_like(tC_gCol); | ||||
|  | ||||
|     // Generate the pred tensor | ||||
|     Tensor cCol = make_identity_tensor(mCol.shape()); | ||||
|     Tensor tC_cCol = group_modes<1,4>( | ||||
|       ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); | ||||
|  | ||||
|     return Callbacks< | ||||
|       decltype(tC_gCol), decltype(tC_rCol), | ||||
|       decltype(tC_cCol), ProblemShape>( | ||||
|       cute::move(tC_gCol), | ||||
|       cute::move(tC_rCol), | ||||
|       cute::move(tC_cCol), | ||||
|       problem_shape, | ||||
|       params_ptr | ||||
|     ); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } | ||||
							
								
								
									
										389
									
								
								csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										389
									
								
								csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,389 @@ | ||||
| /*************************************************************************************************** | ||||
|  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights | ||||
|  *reserved. SPDX-License-Identifier: BSD-3-Clause | ||||
|  * | ||||
|  * Redistribution and use in source and binary forms, with or without | ||||
|  * modification, are permitted provided that the following conditions are met: | ||||
|  * | ||||
|  * 1. Redistributions of source code must retain the above copyright notice, | ||||
|  *this list of conditions and the following disclaimer. | ||||
|  * | ||||
|  * 2. Redistributions in binary form must reproduce the above copyright notice, | ||||
|  * this list of conditions and the following disclaimer in the documentation | ||||
|  * and/or other materials provided with the distribution. | ||||
|  * | ||||
|  * 3. Neither the name of the copyright holder nor the names of its | ||||
|  * contributors may be used to endorse or promote products derived from | ||||
|  * this software without specific prior written permission. | ||||
|  * | ||||
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||
|  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||||
|  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||||
|  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||||
|  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||||
|  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||||
|  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||||
|  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||||
|  *POSSIBILITY OF SUCH DAMAGE. | ||||
|  * | ||||
|  **************************************************************************************************/ | ||||
|  | ||||
| // | ||||
| // This file is a modified excerpt of | ||||
| // include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp | ||||
| // from https://github.com/NVIDIA/cutlass v3.5.0 | ||||
| // It has been modified to support either row/column or scalar broadcasting | ||||
| // where the tensor being loaded from is always passed in via a device pointer. | ||||
| // This lets one compiled kernel handle all cases of per-tensor or | ||||
| // per-channel/per-token quantization. | ||||
| // | ||||
| // This interface also allows the scales to be passed in as tensors that | ||||
| // consistently reside on the device, which avoids an issue with a previous | ||||
| // implementation where scalars needed to be on the CPU since they | ||||
| // were passed in via float values. This created a potential performance hazard | ||||
| // if scales were initially on the device, and caused torch.compile graphs | ||||
| // breaks when moving scales to the CPU. | ||||
| // | ||||
| #pragma once | ||||
|  | ||||
| // Turn off clang-format for the entire file to keep it close to upstream | ||||
| // clang-format off | ||||
|  | ||||
| #include "cutlass/cutlass.h" | ||||
| #include "cutlass/arch/barrier.h" | ||||
|  | ||||
| #include "cute/tensor.hpp" | ||||
| #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" | ||||
|  | ||||
| namespace cutlass::epilogue::fusion { | ||||
|  | ||||
| using namespace cute; | ||||
| using namespace detail; | ||||
|  | ||||
| // Row vector broadcast | ||||
| template< | ||||
|   // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least | ||||
|   // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races | ||||
|   int Stages, | ||||
|   class CtaTileShapeMNK, | ||||
|   class Element, | ||||
|   class StrideMNL = Stride<_0,_1,_0>, | ||||
|   int Alignment = 128 / sizeof_bits_v<Element> | ||||
| > | ||||
| struct Sm90RowOrScalarBroadcast { | ||||
|   static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet"); | ||||
|   static_assert( | ||||
|     (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias | ||||
|     (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>));  // batched row vector broadcast | ||||
|  | ||||
|   // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem | ||||
|   struct SharedStorage { | ||||
|     alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row; | ||||
|   }; | ||||
|  | ||||
|   // This struct has been modified to have a bool indicating that ptr_row is a  | ||||
|   // scalar that must be broadcast, instead of containing a scalar that is  | ||||
|   // valid if ptr_row is null. | ||||
|   struct Arguments { | ||||
|     Element const* ptr_row = nullptr; | ||||
|     bool row_broadcast = true; | ||||
|     StrideMNL dRow = {}; | ||||
|   }; | ||||
|  | ||||
|   using Params = Arguments; | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static constexpr Params | ||||
|   to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { | ||||
|     return args; | ||||
|   } | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static size_t | ||||
|   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static cutlass::Status | ||||
|   initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, | ||||
|     CudaHostAdapter* cuda_adapter = nullptr) { | ||||
|     return cutlass::Status::kSuccess; | ||||
|   } | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   Sm90RowOrScalarBroadcast() { } | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) | ||||
|       : params(params), | ||||
|         smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { } | ||||
|  | ||||
|   Params params; | ||||
|   Element* smem_row; | ||||
|  | ||||
|   CUTLASS_DEVICE bool | ||||
|   is_producer_load_needed() const { | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   CUTLASS_DEVICE bool | ||||
|   is_C_load_needed() const { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   CUTLASS_DEVICE bool | ||||
|   is_zero() const { | ||||
|     return (!params.row_broadcast && *(params.ptr_row) == Element(0)); | ||||
|   } | ||||
|  | ||||
|   template <int EpiTiles, class GTensor, class STensor> | ||||
|   struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { | ||||
|     CUTLASS_DEVICE | ||||
|     ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) | ||||
|       : gRow(cute::forward<GTensor>(gRow)), | ||||
|         sRow(cute::forward<STensor>(sRow)), | ||||
|         params(params) {} | ||||
|  | ||||
|     GTensor gRow;                                                                                 // (CTA_M,CTA_N) | ||||
|     STensor sRow;                                                                                 // (CTA_M,CTA_N,PIPE) | ||||
|     Params const& params; | ||||
|  | ||||
|     CUTLASS_DEVICE void | ||||
|     begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { | ||||
|       if (params.ptr_row == nullptr) { | ||||
|         return; | ||||
|       } | ||||
|  | ||||
|       if (issue_tma_load) { | ||||
|         // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size | ||||
|         constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8; | ||||
|         cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); | ||||
|         // Issue the TMA bulk copy | ||||
|         auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr); | ||||
|         // Filter so we don't issue redundant copies over stride-0 modes | ||||
|         int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; | ||||
|         copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); | ||||
|       } | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   template <class... Args> | ||||
|   CUTLASS_DEVICE auto | ||||
|   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) { | ||||
|  | ||||
|     auto [M, N, K, L] = args.problem_shape_mnkl; | ||||
|     auto [m, n, k, l] = args.tile_coord_mnkl; | ||||
|     Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); | ||||
|     Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));            // (CTA_M,CTA_N) | ||||
|     Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE) | ||||
|                     make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), | ||||
|                     make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); | ||||
|  | ||||
|     constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; | ||||
|     return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>( | ||||
|       cute::move(gRow), cute::move(sRow), params); | ||||
|   } | ||||
|  | ||||
|   template <int EpiTiles, class RTensor, class STensor> | ||||
|   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { | ||||
|     CUTLASS_DEVICE | ||||
|     ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) | ||||
|       : tCrRow(cute::forward<RTensor>(tCrRow)), | ||||
|         tCsRow(cute::forward<STensor>(tCsRow)), | ||||
|         params(params) {} | ||||
|  | ||||
|     RTensor tCrRow;                                                               // (CPY,CPY_M,CPY_N) | ||||
|     STensor tCsRow;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) | ||||
|     Params const& params; | ||||
|  | ||||
|     CUTLASS_DEVICE void | ||||
|     previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { | ||||
|       if (!params.row_broadcast) { | ||||
|         fill(tCrRow, *(params.ptr_row)); | ||||
|         return; | ||||
|       } | ||||
|  | ||||
|       if (epi_m == 0) { // Assumes M-major subtile loop | ||||
|         // Filter so we don't issue redundant copies over stride-0 modes | ||||
|         // (only works if 0-strides are in same location, which is by construction) | ||||
|         int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; | ||||
|         copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     template <typename ElementAccumulator, int FragmentSize> | ||||
|     CUTLASS_DEVICE Array<Element, FragmentSize> | ||||
|     visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) { | ||||
|       Array<Element, FragmentSize> frg_row; | ||||
|  | ||||
|       CUTLASS_PRAGMA_UNROLL | ||||
|       for (int i = 0; i < FragmentSize; ++i) { | ||||
|         frg_row[i] = tCrRow(epi_v * FragmentSize + i); | ||||
|       } | ||||
|  | ||||
|       return frg_row; | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   template < | ||||
|     bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy | ||||
|     class... Args | ||||
|   > | ||||
|   CUTLASS_DEVICE auto | ||||
|   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) { | ||||
|  | ||||
|     Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE) | ||||
|                     make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), | ||||
|                     make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); | ||||
|     Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) | ||||
|                       sRow, args.epi_tile, args.tiled_copy, args.thread_idx); | ||||
|     Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                                           // (CPY,CPY_M,CPY_N) | ||||
|  | ||||
|     constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; | ||||
|     return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>( | ||||
|       cute::move(tCrRow), cute::move(tCsRow), params); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| // Column vector broadcast | ||||
| template< | ||||
|   int Stages, | ||||
|   class CtaTileShapeMNK, | ||||
|   class Element, | ||||
|   class StrideMNL = Stride<_1,_0,_0>, | ||||
|   int Alignment = 128 / sizeof_bits_v<Element> | ||||
| > | ||||
| struct Sm90ColOrScalarBroadcast { | ||||
|   static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); | ||||
|   static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet"); | ||||
|   static_assert( | ||||
|     (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias | ||||
|     (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias | ||||
|  | ||||
|   // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem | ||||
|   struct SharedStorage { }; | ||||
|  | ||||
|   // This struct has been modified to have a bool indicating that ptr_col is a  | ||||
|   // scalar that must be broadcast, instead of containing a scalar that is  | ||||
|   // valid if ptr_col is null. | ||||
|   struct Arguments { | ||||
|     Element const* ptr_col = nullptr; | ||||
|     bool col_broadcast = true; | ||||
|     StrideMNL dCol = {}; | ||||
|   }; | ||||
|  | ||||
|   using Params = Arguments; | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static constexpr Params | ||||
|   to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { | ||||
|     return args; | ||||
|   } | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static size_t | ||||
|   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   template <class ProblemShape> | ||||
|   static cutlass::Status | ||||
|   initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, | ||||
|     CudaHostAdapter* cuda_adapter = nullptr) { | ||||
|     return cutlass::Status::kSuccess; | ||||
|   } | ||||
|  | ||||
|   CUTLASS_DEVICE bool | ||||
|   is_producer_load_needed() const { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   CUTLASS_DEVICE bool | ||||
|   is_C_load_needed() const { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   CUTLASS_DEVICE bool | ||||
|   is_zero() const { | ||||
|     return (!params.col_broadcast && *(params.ptr_col) == Element(0)); | ||||
|   } | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   Sm90ColOrScalarBroadcast() { } | ||||
|  | ||||
|   CUTLASS_HOST_DEVICE | ||||
|   Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) | ||||
|       : params(params) { } | ||||
|  | ||||
|   Params params; | ||||
|  | ||||
|   template <class... Args> | ||||
|   CUTLASS_DEVICE auto | ||||
|   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) { | ||||
|     return EmptyProducerLoadCallbacks{}; | ||||
|   } | ||||
|  | ||||
|   template<class GTensor, class RTensor> | ||||
|   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { | ||||
|     CUTLASS_DEVICE | ||||
|     ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) | ||||
|       : tCgCol(cute::forward<GTensor>(tCgCol)), | ||||
|         tCrCol(cute::forward<RTensor>(tCrCol)), | ||||
|         params(params) {} | ||||
|  | ||||
|     GTensor tCgCol;                                                                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||
|     RTensor tCrCol;                                                                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||
|     Params const& params; | ||||
|  | ||||
|     CUTLASS_DEVICE void | ||||
|     begin() { | ||||
|       if (!params.col_broadcast) { | ||||
|         fill(tCrCol, *(params.ptr_col)); | ||||
|         return; | ||||
|       } | ||||
|  | ||||
|       // Filter so we don't issue redundant copies over stride-0 modes | ||||
|       // (only works if 0-strides are in same location, which is by construction) | ||||
|       copy_aligned(filter(tCgCol), filter(tCrCol)); | ||||
|     } | ||||
|  | ||||
|     template <typename ElementAccumulator, int FragmentSize> | ||||
|     CUTLASS_DEVICE Array<Element, FragmentSize> | ||||
|     visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) { | ||||
|       Array<Element, FragmentSize> frg_col; | ||||
|       Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); | ||||
|  | ||||
|       CUTLASS_PRAGMA_UNROLL | ||||
|       for (int i = 0; i < FragmentSize; ++i) { | ||||
|         frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); | ||||
|       } | ||||
|  | ||||
|       return frg_col; | ||||
|     } | ||||
|  | ||||
|   }; | ||||
|  | ||||
|   template < | ||||
|     bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy | ||||
|     class... Args | ||||
|   > | ||||
|   CUTLASS_DEVICE auto | ||||
|   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) { | ||||
|  | ||||
|     auto [M, N, K, L] = args.problem_shape_mnkl; | ||||
|     Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); | ||||
|     Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(                         // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||
|       mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); | ||||
|     Tensor tCrCol = make_tensor_like(tCgCol);                                          // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||
|  | ||||
|     return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>( | ||||
|       cute::move(tCgCol), cute::move(tCrCol), params); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } | ||||
							
								
								
									
										12
									
								
								csrc/quantization/cutlass_w8a8/common.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								csrc/quantization/cutlass_w8a8/common.hpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "cutlass/cutlass.h" | ||||
|  | ||||
| /** | ||||
|  * Helper function for checking CUTLASS errors | ||||
|  */ | ||||
| #define CUTLASS_CHECK(status)                        \ | ||||
|   {                                                  \ | ||||
|     TORCH_CHECK(status == cutlass::Status::kSuccess, \ | ||||
|                 cutlassGetStatusString(status))      \ | ||||
|   } | ||||
							
								
								
									
										294
									
								
								csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										294
									
								
								csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,294 @@ | ||||
| #include <stddef.h> | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
|  | ||||
| // clang-format will break include orders | ||||
| // clang-format off | ||||
| #include "cute/tensor.hpp" | ||||
| #include "cute/atom/mma_atom.hpp" | ||||
| #include "cutlass/numeric_types.h" | ||||
|  | ||||
| #include "cutlass/util/device_memory.h" | ||||
|  | ||||
| #include "cutlass/cutlass.h" | ||||
| #include "cutlass/gemm_coord.h" | ||||
| #include "cutlass/arch/mma_sm75.h" | ||||
| #include "cutlass/arch/arch.h" | ||||
| #include "cutlass/arch/mma.h" | ||||
| #include "cutlass/gemm/device/gemm.h" | ||||
| #include "cutlass/gemm/device/gemm_universal_adapter.h" | ||||
|  | ||||
| #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" | ||||
| #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" | ||||
|  | ||||
| #include "broadcast_load_epilogue_c2x.hpp" | ||||
| #include "common.hpp" | ||||
| // clang-format on | ||||
|  | ||||
| using namespace cute; | ||||
|  | ||||
| /* | ||||
|    This defines a quantized GEMM operation with dequantized output, similar to | ||||
|    torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for | ||||
|    NVIDIA GPUs with SM versions prior to sm90 (Hopper). | ||||
|  | ||||
|    A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or | ||||
|    per-row. B can be quantized per-tensor or per-column. | ||||
|    Any combination of per-tensor and per-row or column is supported. | ||||
|    A and B must have symmetric quantization (zero point == 0). | ||||
|  | ||||
|    So the GEMM operation is D = (a_scales * A) (b_scales * B), where the | ||||
|    scales are applied elementwise with numpy-style broadcasting. | ||||
|  | ||||
|    ScaleA and ScaleB define the epilogue functions that apply the scales for | ||||
|    the A and B operands respectively. These scales may be either per-tensor or | ||||
|    per row or column. | ||||
| */ | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename Arch, typename ElementAB_, typename ElementD_, | ||||
|           typename TileShape, typename WarpShape, typename InstructionShape, | ||||
|           int32_t MainLoopStages> | ||||
| struct cutlass_2x_gemm { | ||||
|   using ElementAB = ElementAB_; | ||||
|   using ElementD = ElementD_; | ||||
|  | ||||
|   using ElementAcc = | ||||
|       typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, | ||||
|                                 float>::type; | ||||
|  | ||||
|   using Operator = | ||||
|       typename std::conditional<std::is_same_v<ElementAB, int8_t>, | ||||
|                                 cutlass::arch::OpMultiplyAddSaturate, | ||||
|                                 cutlass::arch::OpMultiplyAdd>::type; | ||||
|  | ||||
|   using OutputTileThreadMap = | ||||
|       cutlass::epilogue::threadblock::OutputTileThreadLayout< | ||||
|           TileShape, WarpShape, float, 4, 1 /* epilogue stages */ | ||||
|           >; | ||||
|  | ||||
|   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; | ||||
|  | ||||
|   using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< | ||||
|       OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>; | ||||
|  | ||||
|   using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< | ||||
|       OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>; | ||||
|  | ||||
|   using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< | ||||
|       cutlass::multiplies, float, float, | ||||
|       cutlass::FloatRoundStyle::round_to_nearest>; | ||||
|  | ||||
|   using EVTCompute0 = | ||||
|       cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>; | ||||
|  | ||||
|   using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< | ||||
|       cutlass::multiplies, ElementD, float, | ||||
|       cutlass::FloatRoundStyle::round_to_nearest>; | ||||
|  | ||||
|   using EVTCompute1 = | ||||
|       cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>; | ||||
|  | ||||
|   using D = cutlass::epilogue::threadblock::VisitorAuxStore< | ||||
|       OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, | ||||
|       Stride<int64_t, Int<1>, Int<0>>>; | ||||
|  | ||||
|   using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>; | ||||
|  | ||||
|   // clang-format off | ||||
|   using RowMajor = typename cutlass::layout::RowMajor; | ||||
|   using ColumnMajor = typename cutlass::layout::ColumnMajor; | ||||
|   using KernelType =  | ||||
|     typename cutlass::gemm::kernel::DefaultGemmWithVisitor< | ||||
|       ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,  | ||||
|       ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,  | ||||
|       float, cutlass::layout::RowMajor, 4, | ||||
|       ElementAcc, float, cutlass::arch::OpClassTensorOp,  | ||||
|       Arch,  | ||||
|       TileShape, WarpShape, InstructionShape, | ||||
|       EVTD, | ||||
|       cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, | ||||
|       MainLoopStages, Operator, | ||||
|       1 /* epilogue stages */ | ||||
|       >::GemmKernel; | ||||
|   // clang-format on | ||||
|  | ||||
|   using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>; | ||||
| }; | ||||
|  | ||||
| template <typename Gemm> | ||||
| void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                      torch::Tensor const& b, | ||||
|                                      torch::Tensor const& a_scales, | ||||
|                                      torch::Tensor const& b_scales) { | ||||
|   using ElementAB = typename Gemm::ElementAB; | ||||
|   using ElementD = typename Gemm::ElementD; | ||||
|  | ||||
|   int32_t m = a.size(0); | ||||
|   int32_t n = b.size(1); | ||||
|   int32_t k = a.size(1); | ||||
|   cutlass::gemm::GemmCoord problem_size{m, n, k}; | ||||
|  | ||||
|   int64_t lda = a.stride(0); | ||||
|   int64_t ldb = b.stride(1); | ||||
|   int64_t ldc = out.stride(0); | ||||
|  | ||||
|   using StrideC = Stride<int64_t, Int<1>, Int<0>>; | ||||
|   StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; | ||||
|  | ||||
|   auto a_ptr = static_cast<ElementAB const*>(a.data_ptr()); | ||||
|   auto b_ptr = static_cast<ElementAB const*>(b.data_ptr()); | ||||
|   auto c_ptr = static_cast<ElementD*>(out.data_ptr()); | ||||
|  | ||||
|   auto a_scales_ptr = a_scales.data_ptr<float>(); | ||||
|   auto b_scales_ptr = b_scales.data_ptr<float>(); | ||||
|  | ||||
|   using ScaleAArgs = typename Gemm::ScaleA::Arguments; | ||||
|   using ScaleBArgs = typename Gemm::ScaleB::Arguments; | ||||
|  | ||||
|   ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}}; | ||||
|   ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}}; | ||||
|  | ||||
|   typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; | ||||
|  | ||||
|   typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args, | ||||
|                                                           evt0_compute_args}; | ||||
|   typename Gemm::D::Arguments d_args{c_ptr, c_stride}; | ||||
|  | ||||
|   typename Gemm::EVTD::Arguments epilogue_args{ | ||||
|       evt1_compute_args, | ||||
|       d_args, | ||||
|   }; | ||||
|  | ||||
|   typename Gemm::Op::Arguments args{ | ||||
|       cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode | ||||
|       problem_size,                                           // problem size | ||||
|       1,                                                      // batch count | ||||
|       epilogue_args, | ||||
|       a_ptr, | ||||
|       b_ptr, | ||||
|       nullptr, | ||||
|       nullptr, | ||||
|       0, | ||||
|       0, | ||||
|       0, | ||||
|       0, | ||||
|       lda, | ||||
|       ldb, | ||||
|       ldc, | ||||
|       ldc}; | ||||
|  | ||||
|   // Launch the CUTLASS GEMM kernel. | ||||
|   typename Gemm::Op gemm_op; | ||||
|   size_t workspace_size = gemm_op.get_workspace_size(args); | ||||
|   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); | ||||
|  | ||||
|   auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); | ||||
|  | ||||
|   CUTLASS_CHECK(gemm_op.can_implement(args)); | ||||
|   cutlass::Status status = gemm_op(args, workspace.get(), stream); | ||||
|   CUTLASS_CHECK(status); | ||||
| } | ||||
|  | ||||
| }  // namespace | ||||
|  | ||||
| void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales) { | ||||
|   TORCH_CHECK(a.dtype() == torch::kInt8); | ||||
|   TORCH_CHECK(b.dtype() == torch::kInt8); | ||||
|   TORCH_CHECK(a_scales.dtype() == torch::kFloat32); | ||||
|   TORCH_CHECK(b_scales.dtype() == torch::kFloat32); | ||||
|  | ||||
|   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; | ||||
|   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; | ||||
|   using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; | ||||
|  | ||||
|   if (out.dtype() == torch::kBFloat16) { | ||||
|     return cutlass_scaled_mm_dq_dispatcher< | ||||
|         cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t, | ||||
|                         TileShape, WarpShape, InstructionShape, 2>>( | ||||
|         out, a, b, a_scales, b_scales); | ||||
|   } else { | ||||
|     TORCH_CHECK(out.dtype() == torch::kFloat16); | ||||
|     return cutlass_scaled_mm_dq_dispatcher< | ||||
|         cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape, | ||||
|                         WarpShape, InstructionShape, 2>>(out, a, b, a_scales, | ||||
|                                                          b_scales); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales) { | ||||
|   TORCH_CHECK(a.dtype() == torch::kInt8); | ||||
|   TORCH_CHECK(b.dtype() == torch::kInt8); | ||||
|   TORCH_CHECK(a_scales.dtype() == torch::kFloat32); | ||||
|   TORCH_CHECK(b_scales.dtype() == torch::kFloat32); | ||||
|  | ||||
|   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; | ||||
|   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; | ||||
|   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | ||||
|  | ||||
|   if (out.dtype() == torch::kBFloat16) { | ||||
|     return cutlass_scaled_mm_dq_dispatcher< | ||||
|         cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t, | ||||
|                         TileShape, WarpShape, InstructionShape, 5>>( | ||||
|         out, a, b, a_scales, b_scales); | ||||
|   } else { | ||||
|     TORCH_CHECK(out.dtype() == torch::kFloat16); | ||||
|     return cutlass_scaled_mm_dq_dispatcher< | ||||
|         cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape, | ||||
|                         WarpShape, InstructionShape, 5>>(out, a, b, a_scales, | ||||
|                                                          b_scales); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales) { | ||||
|   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; | ||||
|   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; | ||||
|   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | ||||
|  | ||||
|   TORCH_CHECK(a_scales.dtype() == torch::kFloat32); | ||||
|   TORCH_CHECK(b_scales.dtype() == torch::kFloat32); | ||||
|  | ||||
|   if (a.dtype() == torch::kInt8) { | ||||
|     TORCH_CHECK(b.dtype() == torch::kInt8); | ||||
|  | ||||
|     if (out.dtype() == torch::kBFloat16) { | ||||
|       return cutlass_scaled_mm_dq_dispatcher< | ||||
|           cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t, | ||||
|                           TileShape, WarpShape, InstructionShape, 5>>( | ||||
|           out, a, b, a_scales, b_scales); | ||||
|     } else { | ||||
|       assert(out.dtype() == torch::kFloat16); | ||||
|       return cutlass_scaled_mm_dq_dispatcher< | ||||
|           cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t, | ||||
|                           TileShape, WarpShape, InstructionShape, 5>>( | ||||
|           out, a, b, a_scales, b_scales); | ||||
|     } | ||||
|   } else { | ||||
|     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); | ||||
|     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); | ||||
|  | ||||
|     if (out.dtype() == torch::kBFloat16) { | ||||
|       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< | ||||
|           cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t, | ||||
|           TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales, | ||||
|                                                       b_scales); | ||||
|     } else { | ||||
|       TORCH_CHECK(out.dtype() == torch::kFloat16); | ||||
|       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm< | ||||
|           cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t, | ||||
|           TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales, | ||||
|                                                       b_scales); | ||||
|     } | ||||
|   } | ||||
| } | ||||
							
								
								
									
										325
									
								
								csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										325
									
								
								csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,325 @@ | ||||
| // clang-format will break include orders | ||||
| // clang-format off | ||||
| #include <cudaTypedefs.h> | ||||
|  | ||||
| #if defined CUDA_VERSION && CUDA_VERSION >= 12000 | ||||
|  | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
|  | ||||
| #include <iostream> | ||||
| #include <sstream> | ||||
| #include <vector> | ||||
|  | ||||
| #include "cutlass/cutlass.h" | ||||
|  | ||||
| #include "cute/tensor.hpp" | ||||
| #include "cute/atom/mma_atom.hpp" | ||||
| #include "cutlass/numeric_types.h" | ||||
|  | ||||
| #include "cutlass/util/device_memory.h" | ||||
|  | ||||
| #include "cutlass/gemm/device/gemm_universal_adapter.h" | ||||
| #include "cutlass/gemm/kernel/gemm_universal.hpp" | ||||
| #include "cutlass/epilogue/collective/collective_builder.hpp" | ||||
| #include "cutlass/gemm/collective/collective_builder.hpp" | ||||
|  | ||||
| #include "broadcast_load_epilogue_c3x.hpp" | ||||
| #include "common.hpp" | ||||
| // clang-format on | ||||
|  | ||||
| using namespace cute; | ||||
|  | ||||
| /* | ||||
|    This defines a quantized GEMM operation with dequantized output, similar to | ||||
|    torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for | ||||
|    NVIDIA GPUs with sm90a (Hopper) or later. | ||||
|  | ||||
|    A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or | ||||
|    per-row. B can be quantized per-tensor or per-column. | ||||
|    Any combination of per-tensor and per-row or column is supported. | ||||
|    A and B must have symmetric quantization (zero point == 0). | ||||
|  | ||||
|    So the GEMM operation is D = (a_scales * A) (b_scales * B), where the | ||||
|    scales are applied elementwise with numpy-style broadcasting. | ||||
|  | ||||
|    ScaleA and ScaleB define the epilogue functions that apply the scales for | ||||
|    the A and B operands respectively. These scales may be either per-tensor or | ||||
|    per row or column. | ||||
| */ | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| uint32_t next_pow_2(uint32_t const num) { | ||||
|   if (num <= 1) return num; | ||||
|   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); | ||||
| } | ||||
|  | ||||
| template <typename ElementAB_, typename ElementD_, typename TileShape, | ||||
|           typename ClusterShape, typename KernelSchedule, | ||||
|           typename EpilogueSchedule> | ||||
| struct cutlass_3x_gemm { | ||||
|   using ElementAB = ElementAB_; | ||||
|   using ElementD = ElementD_; | ||||
|   using ElementAcc = | ||||
|       typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, | ||||
|                                 float>::type; | ||||
|  | ||||
|   using EpilogueDescriptor = | ||||
|       cutlass::epilogue::collective::detail::EpilogueDescriptor< | ||||
|           TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, | ||||
|           ElementD, EpilogueSchedule>; | ||||
|  | ||||
|   using Accum = cutlass::epilogue::fusion::Sm90AccFetch; | ||||
|  | ||||
|   using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< | ||||
|       0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, | ||||
|       Stride<Int<1>, Int<0>, Int<0>>>; | ||||
|  | ||||
|   using ScaleBDescriptor = | ||||
|       cutlass::epilogue::collective::detail::RowBroadcastDescriptor< | ||||
|           EpilogueDescriptor, float>; | ||||
|  | ||||
|   using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< | ||||
|       ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, | ||||
|       typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>; | ||||
|  | ||||
|   using Compute0 = cutlass::epilogue::fusion::Sm90Compute< | ||||
|       cutlass::multiplies, float, float, | ||||
|       cutlass::FloatRoundStyle::round_to_nearest>; | ||||
|  | ||||
|   using EVTCompute0 = | ||||
|       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>; | ||||
|  | ||||
|   using Compute1 = cutlass::epilogue::fusion::Sm90Compute< | ||||
|       cutlass::multiplies, ElementD, float, | ||||
|       cutlass::FloatRoundStyle::round_to_nearest>; | ||||
|  | ||||
|   using EVTCompute1 = | ||||
|       cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>; | ||||
|  | ||||
|   using StrideD = Stride<int64_t, Int<1>, Int<0>>; | ||||
|   using ElementC = void; | ||||
|   using StrideC = StrideD; | ||||
|  | ||||
|   using CollectiveEpilogue = | ||||
|       typename cutlass::epilogue::collective::CollectiveBuilder< | ||||
|           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, | ||||
|           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, | ||||
|           ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, | ||||
|           EpilogueSchedule, EVTCompute1>::CollectiveOp; | ||||
|  | ||||
|   static constexpr size_t CEStorageSize = | ||||
|       sizeof(typename CollectiveEpilogue::SharedStorage); | ||||
|   using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< | ||||
|       static_cast<int>(CEStorageSize)>; | ||||
|  | ||||
|   // clang-format off | ||||
|   using CollectiveMainloop = | ||||
|       typename cutlass::gemm::collective::CollectiveBuilder< | ||||
|           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,  | ||||
|           ElementAB, cutlass::layout::RowMajor, 16,  | ||||
|           ElementAB, cutlass::layout::ColumnMajor, 16,  | ||||
|           ElementAcc, TileShape, ClusterShape, | ||||
|           Stages, | ||||
|           KernelSchedule>::CollectiveOp; | ||||
|   // clang-format on | ||||
|  | ||||
|   using KernelType = cutlass::gemm::kernel::GemmUniversal< | ||||
|       cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, | ||||
|       cutlass::gemm::PersistentScheduler>; | ||||
|  | ||||
|   struct GemmKernel : public KernelType {}; | ||||
| }; | ||||
|  | ||||
| template <typename Gemm> | ||||
| void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                      torch::Tensor const& b, | ||||
|                                      torch::Tensor const& a_scales, | ||||
|                                      torch::Tensor const& b_scales) { | ||||
|   using ElementAB = typename Gemm::ElementAB; | ||||
|   using ElementD = typename Gemm::ElementD; | ||||
|  | ||||
|   int32_t m = a.size(0); | ||||
|   int32_t n = b.size(1); | ||||
|   int32_t k = a.size(1); | ||||
|  | ||||
|   int64_t lda = a.stride(0); | ||||
|   int64_t ldb = b.stride(1); | ||||
|   int64_t ldc = out.stride(0); | ||||
|  | ||||
|   using StrideA = Stride<int64_t, Int<1>, Int<0>>; | ||||
|   using StrideB = Stride<int64_t, Int<1>, Int<0>>; | ||||
|   using StrideC = typename Gemm::StrideC; | ||||
|  | ||||
|   StrideA a_stride{lda, Int<1>{}, Int<0>{}}; | ||||
|   StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; | ||||
|   StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; | ||||
|  | ||||
|   using GemmKernel = typename Gemm::GemmKernel; | ||||
|   typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; | ||||
|  | ||||
|   auto a_ptr = static_cast<ElementAB*>(a.data_ptr()); | ||||
|   auto b_ptr = static_cast<ElementAB*>(b.data_ptr()); | ||||
|   typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, | ||||
|                                                        b_stride}; | ||||
|  | ||||
|   auto c_ptr = static_cast<ElementD*>(out.data_ptr()); | ||||
|   typename GemmKernel::EpilogueArguments epilogue_args{ | ||||
|       {}, c_ptr, c_stride, c_ptr, c_stride}; | ||||
|  | ||||
|   typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, | ||||
|                                       prob_shape, mainloop_args, epilogue_args}; | ||||
|  | ||||
|   using ScaleA_Args = typename Gemm::ScaleA::Arguments; | ||||
|   using ScaleB_Args = typename Gemm::ScaleB::Arguments; | ||||
|  | ||||
|   ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}}; | ||||
|   ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}}; | ||||
|  | ||||
|   args.epilogue.thread = {a_args, {b_args}}; | ||||
|  | ||||
|   // Launch the CUTLASS GEMM kernel. | ||||
|   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; | ||||
|   GemmOp gemm_op; | ||||
|   CUTLASS_CHECK(gemm_op.can_implement(args)); | ||||
|  | ||||
|   size_t workspace_size = gemm_op.get_workspace_size(args); | ||||
|   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); | ||||
|  | ||||
|   auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); | ||||
|  | ||||
|   cutlass::Status status = gemm_op.run(args, workspace.get(), stream); | ||||
|   CUTLASS_CHECK(status); | ||||
| } | ||||
|  | ||||
| template <typename InType, typename OutType, int32_t M> | ||||
| struct sm90_fp8_config { | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = | ||||
|       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; | ||||
|   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; | ||||
|   using TileShape = Shape<_128, _128, _128>; | ||||
|   using ClusterShape = Shape<_2, _1, _1>; | ||||
|  | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule, | ||||
|                       EpilogueSchedule>; | ||||
| }; | ||||
|  | ||||
| template <typename InType, typename OutType> | ||||
| struct sm90_fp8_config<InType, OutType, 128> { | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = | ||||
|       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; | ||||
|   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; | ||||
|   using TileShape = Shape<_64, _128, _128>; | ||||
|   using ClusterShape = Shape<_2, _1, _1>; | ||||
|  | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule, | ||||
|                       EpilogueSchedule>; | ||||
| }; | ||||
|  | ||||
| template <typename InType, typename OutType> | ||||
| struct sm90_fp8_config<InType, OutType, 64> { | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = | ||||
|       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; | ||||
|   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; | ||||
|   using TileShape = Shape<_64, _64, _128>; | ||||
|   using ClusterShape = Shape<_1, _8, _1>; | ||||
|  | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule, | ||||
|                       EpilogueSchedule>; | ||||
| }; | ||||
|  | ||||
| }  // namespace | ||||
|  | ||||
| template <typename InType, typename OutType> | ||||
| void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out, | ||||
|                                             torch::Tensor const& a, | ||||
|                                             torch::Tensor const& b, | ||||
|                                             torch::Tensor const& a_scales, | ||||
|                                             torch::Tensor const& b_scales) { | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); | ||||
|   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); | ||||
|   TORCH_CHECK(a_scales.dtype() == torch::kFloat32); | ||||
|   TORCH_CHECK(b_scales.dtype() == torch::kFloat32); | ||||
|  | ||||
|   using Cutlass3xGemmDefault = | ||||
|       typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm; | ||||
|   using Cutlass3xGemmM64 = | ||||
|       typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm; | ||||
|   using Cutlass3xGemmM128 = | ||||
|       typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm; | ||||
|  | ||||
|   uint32_t const m = a.size(0); | ||||
|   uint32_t const mp2 = | ||||
|       std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2 | ||||
|  | ||||
|   if (mp2 <= 64) { | ||||
|     // m in [1, 64] | ||||
|     return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>( | ||||
|         out, a, b, a_scales, b_scales); | ||||
|   } else if (mp2 <= 128) { | ||||
|     // m in (64, 128] | ||||
|     return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>( | ||||
|         out, a, b, a_scales, b_scales); | ||||
|   } else { | ||||
|     // m in (128, inf) | ||||
|     return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>( | ||||
|         out, a, b, a_scales, b_scales); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales) { | ||||
|   TORCH_CHECK(a_scales.dtype() == torch::kFloat32); | ||||
|   TORCH_CHECK(b_scales.dtype() == torch::kFloat32); | ||||
|  | ||||
|   if (a.dtype() == torch::kInt8) { | ||||
|     TORCH_CHECK(b.dtype() == torch::kInt8); | ||||
|  | ||||
|     using TileShape = Shape<_128, _128, _128>; | ||||
|     using ClusterShape = Shape<_1, _2, _1>; | ||||
|     using KernelSchedule = | ||||
|         typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; | ||||
|     using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; | ||||
|  | ||||
|     if (out.dtype() == torch::kBFloat16) { | ||||
|       return cutlass_scaled_mm_dq_dispatcher< | ||||
|           cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape, | ||||
|                           KernelSchedule, EpilogueSchedule>>( | ||||
|           out, a, b, a_scales, b_scales); | ||||
|     } else { | ||||
|       TORCH_CHECK(out.dtype() == torch::kFloat16); | ||||
|  | ||||
|       return cutlass_scaled_mm_dq_dispatcher< | ||||
|           cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape, | ||||
|                           KernelSchedule, EpilogueSchedule>>( | ||||
|           out, a, b, a_scales, b_scales); | ||||
|     } | ||||
|   } else { | ||||
|     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); | ||||
|     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); | ||||
|  | ||||
|     if (out.dtype() == torch::kBFloat16) { | ||||
|       return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t, | ||||
|                                                     cutlass::bfloat16_t>( | ||||
|           out, a, b, a_scales, b_scales); | ||||
|     } else { | ||||
|       TORCH_CHECK(out.dtype() == torch::kFloat16); | ||||
|       return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t, | ||||
|                                                     cutlass::half_t>( | ||||
|           out, a, b, a_scales, b_scales); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| #endif | ||||
							
								
								
									
										75
									
								
								csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,75 @@ | ||||
| #include <cudaTypedefs.h> | ||||
|  | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
| #include <torch/extension.h> | ||||
|  | ||||
| void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales); | ||||
|  | ||||
| void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales); | ||||
|  | ||||
| void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales); | ||||
|  | ||||
| #if defined CUDA_VERSION && CUDA_VERSION >= 12000 | ||||
| void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, | ||||
|                                torch::Tensor const& b, | ||||
|                                torch::Tensor const& a_scales, | ||||
|                                torch::Tensor const& b_scales); | ||||
| #endif | ||||
|  | ||||
| void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, | ||||
|                           torch::Tensor const& b, torch::Tensor const& a_scales, | ||||
|                           torch::Tensor const& b_scales) { | ||||
|   int32_t major_capability; | ||||
|   int32_t minor_capability; | ||||
|   cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, | ||||
|                          0); | ||||
|   cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, | ||||
|                          0); | ||||
|   int32_t version_num = major_capability * 10 + minor_capability; | ||||
|  | ||||
|   // Checks for conformality | ||||
|   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); | ||||
|   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && | ||||
|               b.size(1) == c.size(1)); | ||||
|   TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); | ||||
|   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); | ||||
|  | ||||
|   // Check for strides and alignment | ||||
|   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major | ||||
|   TORCH_CHECK(b.stride(0) == 1);                      // Column-major | ||||
|   TORCH_CHECK(c.stride(0) % 16 == 0 && | ||||
|               b.stride(1) % 16 == 0);  // 16 Byte Alignment | ||||
|   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); | ||||
|  | ||||
|   at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); | ||||
|  | ||||
|   if (version_num >= 90) { | ||||
|     // Hopper | ||||
|  | ||||
|     // Guard against compilation issues for sm90 kernels | ||||
| #if defined CUDA_VERSION && CUDA_VERSION >= 12000 | ||||
|     cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); | ||||
| #else | ||||
|     cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); | ||||
| #endif | ||||
|   } else if (version_num == 89) { | ||||
|     // Ada Lovelace | ||||
|     cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales); | ||||
|   } else if (version_num >= 80) { | ||||
|     // Ampere | ||||
|     cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); | ||||
|   } else { | ||||
|     // Turing | ||||
|     TORCH_CHECK(version_num >= 75); | ||||
|     cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										137
									
								
								csrc/quantization/fp8/amd/hip_float8.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								csrc/quantization/fp8/amd/hip_float8.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,137 @@ | ||||
| #pragma once | ||||
|  | ||||
| #ifdef __HIPCC__ | ||||
|   #include <hip/hip_runtime.h> | ||||
| #else | ||||
|   #include <type_traits> | ||||
|   #include <stdint.h> | ||||
|   #include <math.h> | ||||
|   #include <iostream> | ||||
| #endif | ||||
|  | ||||
| #include "hip_float8_impl.h" | ||||
|  | ||||
| struct alignas(1) hip_fp8 { | ||||
|   struct from_bits_t {}; | ||||
|   HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { | ||||
|     return from_bits_t(); | ||||
|   } | ||||
|   uint8_t data; | ||||
|  | ||||
|   hip_fp8() = default; | ||||
|   HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; | ||||
|   HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; | ||||
|   explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) | ||||
|       : data(v) {} | ||||
|  | ||||
| #ifdef __HIP__MI300__ | ||||
|   // NOTE: ON-DEVICE... always optimal bias | ||||
|   explicit HIP_FP8_DEVICE hip_fp8(float v) | ||||
|       : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} | ||||
|  | ||||
|   explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) | ||||
|       : hip_fp8(static_cast<float>(v)) {} | ||||
|  | ||||
|   // Host only implementation using s/w simulation | ||||
|   explicit HIP_FP8_HOST | ||||
| #else   // __HIP__MI300__ | ||||
|   // both Host and DEVICE for non-MI300 using s/w simulation | ||||
|   explicit HIP_FP8_HOST_DEVICE | ||||
| #endif  // __HIP__MI300__ | ||||
|   hip_fp8(float v) { | ||||
|     data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, | ||||
|                                    true /*clip*/>(v); | ||||
|   } | ||||
|  | ||||
|   explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) | ||||
|       : hip_fp8(static_cast<float>(v)) {} | ||||
|  | ||||
| #ifdef __HIP__MI300__ | ||||
|   // upcast using device specific intrinsic | ||||
|   explicit inline HIP_FP8_DEVICE operator float() const { | ||||
|     float fval; | ||||
|     uint32_t i32val = static_cast<uint32_t>(data); | ||||
|  | ||||
|     // upcast | ||||
|     asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" | ||||
|                  : "=v"(fval) | ||||
|                  : "v"(i32val)); | ||||
|  | ||||
|     return fval; | ||||
|   } | ||||
|  | ||||
|   explicit inline HIP_FP8_HOST operator float() const | ||||
| #else   // __HIP__MI300__ | ||||
|   explicit inline HIP_FP8_HOST_DEVICE operator float() const | ||||
| #endif  // __HIP__MI300__ | ||||
|   { | ||||
|     return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( | ||||
|         data); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| namespace std { | ||||
| inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } | ||||
| inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } | ||||
| HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } | ||||
| }  // namespace std | ||||
|  | ||||
| // Special operator overloading | ||||
| inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { | ||||
|   return os << float(f8); | ||||
| } | ||||
|  | ||||
| // all + operator overloading with mixed types | ||||
| // mixed types, always converts to f32, does computation in f32, and returns | ||||
| // float | ||||
| inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { | ||||
|   return (fa + float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { | ||||
|   return (float(a) + fb); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { | ||||
|   return hip_fp8(float(a) + float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { | ||||
|   return a = hip_fp8(float(a) + float(b)); | ||||
| } | ||||
|  | ||||
| // overloading multiplication, always returns float, | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { | ||||
|   return float(a) * float(b); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { | ||||
|   return (a * float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { | ||||
|   return (float(a) * b); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { | ||||
|   return ((float)a * float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { | ||||
|   return ((float)a * float(b)); | ||||
| } | ||||
|  | ||||
| // overloading for compare | ||||
| inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { | ||||
|   return (a.data == b.data); | ||||
| } | ||||
| inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { | ||||
|   return (a.data != b.data); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { | ||||
|   return static_cast<float>(a) >= static_cast<float>(b); | ||||
| } | ||||
| inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { | ||||
|   return static_cast<float>(a) > static_cast<float>(b); | ||||
| } | ||||
							
								
								
									
										316
									
								
								csrc/quantization/fp8/amd/hip_float8_impl.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										316
									
								
								csrc/quantization/fp8/amd/hip_float8_impl.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,316 @@ | ||||
| #pragma once | ||||
|  | ||||
| #if defined(__HIPCC__) && \ | ||||
|     (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) | ||||
|   #define __HIP__MI300__ | ||||
| #endif | ||||
|  | ||||
| #ifdef __HIPCC__ | ||||
|   #define HIP_FP8_HOST_DEVICE __host__ __device__ | ||||
|   #define HIP_FP8_HOST __host__ | ||||
|   #define HIP_FP8_DEVICE __device__ | ||||
| #else | ||||
|   #define HIP_FP8_HOST_DEVICE | ||||
|   #define HIP_FP8_HOST | ||||
|   #define HIP_FP8_DEVICE | ||||
| #endif | ||||
|  | ||||
| namespace hip_fp8_impl { | ||||
|  | ||||
| #ifdef __HIP__MI300__ | ||||
| HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { | ||||
|   uint8_t i8data; | ||||
|   union { | ||||
|     float fval; | ||||
|     uint32_t i32val; | ||||
|     uint8_t i8val[4];  // NOTE: not endian independent | ||||
|   } val; | ||||
|  | ||||
|   uint32_t ival = 0; | ||||
|   val.fval = v; | ||||
|  | ||||
|   if ((val.i32val & 0x7F800000) != | ||||
|       0x7F800000) {  /// propagate NAN/INF, no clipping | ||||
|     val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); | ||||
|   } | ||||
|  | ||||
|   ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, | ||||
|                                          false);  // false -> WORD0 | ||||
|   val.i32val = ival; | ||||
|   i8data = val.i8val[0]; | ||||
|  | ||||
|   return i8data; | ||||
| } | ||||
| #endif  // __HIP__MI300__ | ||||
|  | ||||
| HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } | ||||
| #if defined(__HIPCC__) || defined(__CUDA_ARCH__) | ||||
| HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } | ||||
| #endif | ||||
|  | ||||
| template <int we, int wm, typename T, bool negative_zero_nan, bool clip> | ||||
| HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, | ||||
|                                       uint32_t rng = 0) { | ||||
| #ifdef __HIPCC__ | ||||
|   constexpr bool is_half = std::is_same<T, _Float16>::value; | ||||
| #else | ||||
|   constexpr bool is_half = false; | ||||
| #endif | ||||
|   constexpr bool is_float = std::is_same<T, float>::value; | ||||
|   static_assert(wm + we == 7, "wm+we==7"); | ||||
|   static_assert(is_half || is_float, "Only half and float can be cast to f8"); | ||||
|  | ||||
|   const int mfmt = (sizeof(T) == 4) ? 23 : 10; | ||||
|   uint32_t x; | ||||
|   if (sizeof(T) == 4) { | ||||
|     x = reinterpret_cast<uint32_t&>(_x); | ||||
|   } else { | ||||
|     x = reinterpret_cast<uint16_t&>(_x); | ||||
|   } | ||||
|  | ||||
|   uint32_t head, mantissa; | ||||
|   int exponent, bias; | ||||
|   uint32_t sign; | ||||
|  | ||||
|   if (sizeof(T) == 4) { | ||||
|     head = x & 0xFF800000; | ||||
|     mantissa = x & 0x7FFFFF; | ||||
|     exponent = (head >> 23) & 0xFF; | ||||
|     sign = head >> 31; | ||||
|     bias = 127; | ||||
|   } else { | ||||
|     head = x & 0xFC00; | ||||
|     mantissa = x & 0x3FF; | ||||
|     exponent = (head >> 10) & 0x1F; | ||||
|     sign = head >> 15; | ||||
|     bias = 15; | ||||
|   } | ||||
|  | ||||
|   uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); | ||||
|  | ||||
|   // Deal with inf and NaNs | ||||
|   if (negative_zero_nan) { | ||||
|     if (sizeof(T) == 4) { | ||||
|       if ((x & 0x7F800000) == 0x7F800000) { | ||||
|         return 0x80; | ||||
|       } | ||||
|     } else { | ||||
|       // if(__hisinf(x) || __hisnan(x)) | ||||
|       if ((x & 0x7C00) == 0x7C00) { | ||||
|         return 0x80; | ||||
|       } | ||||
|     } | ||||
|   } else { | ||||
|     if (sizeof(T) == 4) { | ||||
|       if ((x & 0x7F800000) == 0x7F800000) { | ||||
|         return signed_inf + (mantissa != 0 ? 1 : 0); | ||||
|       } | ||||
|     } else { | ||||
|       if ((x & 0x7C00) == 0x7C00) { | ||||
|         return signed_inf + (mantissa != 0 ? 1 : 0); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   if (x == 0) { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   // First need to check if it is normal or denorm as there is a difference of | ||||
|   // implicit 1 Then need to adjust the exponent to align with the F8 exponent, | ||||
|   // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng | ||||
|   // to mantissa and truncate. And for RNE, no need to add rng. Then probably | ||||
|   // need to check whether there is carry and adjust exponent and mantissa again | ||||
|  | ||||
|   // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent | ||||
|   // bits | ||||
|   const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); | ||||
|   const int f8_denormal_act_exponent = | ||||
|       1 - f8_bias;  // actual exponent of f8 denormal | ||||
|   // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) | ||||
|   // f8_exponent is the converted f8 exponent with bias encoding | ||||
|   // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, | ||||
|   // the difference needs to be adjusted and mantissa shifted | ||||
|   int act_exponent, f8_exponent, exponent_diff; | ||||
|  | ||||
|   if (exponent == 0) {  // fp32/fp16 is in denormal. | ||||
|     /* fp32 denormal is below 2^-127 so it is usually not a concern here, we | ||||
| mostly concern fp16 here. In this case, f8 is usually in denormal. But there | ||||
| could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has | ||||
| exponent bias 16. It means that there are some numbers in fp16 denormal but they | ||||
| are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers | ||||
| where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 | ||||
| (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1  */ | ||||
|     act_exponent = exponent - bias + 1; | ||||
|     exponent_diff = | ||||
|         f8_denormal_act_exponent - | ||||
|         act_exponent;  // actual exponent is exponent-bias+1 as it is denormal | ||||
|   } else {             // fp32/fp16 is normal with implicit 1 | ||||
|     act_exponent = exponent - bias; | ||||
|     if (act_exponent <= f8_denormal_act_exponent) { | ||||
|       /* This is the case where fp32/fp16 is normal but it is in f8 denormal | ||||
| range. For example fp8 nanoo mode, denormal exponent is -7, but if the | ||||
| fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, | ||||
| Therefore it needs to be adjust to -6 and mantissa shift right by 1. | ||||
| So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ | ||||
|       exponent_diff = f8_denormal_act_exponent - act_exponent; | ||||
|     } else {              // both fp32/fp16 and f8 are in normal range | ||||
|       exponent_diff = 0;  // exponent_diff=0 does not mean there is no | ||||
|                           // difference for this case, act_exponent could be | ||||
|                           // larger. Just that it does not need shift mantissa | ||||
|     } | ||||
|     mantissa += (1 << mfmt);  // Add the implicit 1 into mantissa | ||||
|   } | ||||
|  | ||||
|   bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == | ||||
|                   static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1)); | ||||
|   /* This part is a bit tricky. The judgment of whether it is a tie needs to be | ||||
|  done before we shift right as shift right could rip off some residual part | ||||
|  and make something not midpoint look like midpoint. For example, the fp16 | ||||
|  number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after | ||||
|  shift right by 4 bits, it would look like midpoint. | ||||
| */ | ||||
|  | ||||
|   if (exponent_diff > 0) { | ||||
|     mantissa >>= exponent_diff; | ||||
|   } else if (exponent_diff == -1) { | ||||
|     mantissa <<= -exponent_diff; | ||||
|   } | ||||
|   bool implicit_one = mantissa & (1 << mfmt); | ||||
|   // if there is no implicit 1, it  means the f8 is denormal and need to adjust | ||||
|   // to denorm exponent | ||||
|   f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + | ||||
|                 f8_bias - (implicit_one ? 0 : 1); | ||||
|  | ||||
|   // Now we have the exponent and mantissa adjusted | ||||
|   uint32_t drop_mask = (1 << (mfmt - wm)) - 1; | ||||
|   bool odd = mantissa & (1 << (mfmt - wm));  // if the least significant bit | ||||
|                                              // that is not truncated is 1 | ||||
|   mantissa += | ||||
|       (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & | ||||
|       drop_mask; | ||||
|  | ||||
|   // Now we deal with overflow | ||||
|   if (f8_exponent == 0) { | ||||
|     if ((1 << mfmt) & mantissa) { | ||||
|       f8_exponent = 1;  // denormal overflow to become normal, promote exponent | ||||
|     } | ||||
|   } else { | ||||
|     if ((1 << (mfmt + 1)) & mantissa) { | ||||
|       mantissa >>= 1; | ||||
|       f8_exponent++; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   mantissa >>= (mfmt - wm); | ||||
|  | ||||
|   // above range: quantize to maximum possible float of the same sign | ||||
|   const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); | ||||
|   if (f8_exponent > max_exp) { | ||||
|     if (clip) { | ||||
|       mantissa = (1 << wm) - 1; | ||||
|       f8_exponent = max_exp; | ||||
|     } else { | ||||
|       return signed_inf; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (f8_exponent == 0 && mantissa == 0) { | ||||
|     return negative_zero_nan ? 0 : (sign << 7); | ||||
|   } | ||||
|   mantissa &= (1 << wm) - 1; | ||||
|   return (sign << 7) | (f8_exponent << wm) | mantissa; | ||||
| } | ||||
|  | ||||
| template <int we, int wm, typename T = float, bool negative_zero_nan = true> | ||||
| inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { | ||||
| #ifdef __HIPCC__ | ||||
|   constexpr bool is_half = std::is_same<T, _Float16>::value; | ||||
| #else | ||||
|   constexpr bool is_half = false; | ||||
| #endif | ||||
|   constexpr bool is_float = std::is_same<T, float>::value; | ||||
|   static_assert(is_half || is_float, "only half and float are supported"); | ||||
|  | ||||
|   constexpr int weo = is_half ? 5 : 8; | ||||
|   constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); | ||||
|  | ||||
|   T fInf, fNegInf, fNaN, fNeg0; | ||||
|  | ||||
| #ifdef __HIPCC__ | ||||
|   if (is_half) { | ||||
|     const uint16_t ihInf = 0x7C00; | ||||
|     const uint16_t ihNegInf = 0xFC00; | ||||
|     const uint16_t ihNaN = 0x7C01; | ||||
|     const uint16_t ihNeg0 = 0x8000; | ||||
|     fInf = reinterpret_cast<const _Float16&>(ihInf); | ||||
|     fNegInf = reinterpret_cast<const _Float16&>(ihNegInf); | ||||
|     fNaN = reinterpret_cast<const _Float16&>(ihNaN); | ||||
|     fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0); | ||||
|   } else | ||||
| #endif | ||||
|       if (is_float) { | ||||
|     const uint32_t ifInf = 0x7F800000; | ||||
|     const uint32_t ifNegInf = 0xFF800000; | ||||
|     const uint32_t ifNaN = 0x7F800001; | ||||
|     const uint32_t ifNeg0 = 0x80000000; | ||||
|     fInf = reinterpret_cast<const float&>(ifInf); | ||||
|     fNegInf = reinterpret_cast<const float&>(ifNegInf); | ||||
|     fNaN = reinterpret_cast<const float&>(ifNaN); | ||||
|     fNeg0 = reinterpret_cast<const float&>(ifNeg0); | ||||
|   } | ||||
|  | ||||
|   if (x == 0) { | ||||
|     return 0; | ||||
|   } | ||||
|  | ||||
|   uint32_t sign = x >> 7; | ||||
|   uint32_t mantissa = x & ((1 << wm) - 1); | ||||
|   int exponent = (x & 0x7F) >> wm; | ||||
|   if (negative_zero_nan) { | ||||
|     if (x == 0x80) { | ||||
|       return fNaN; | ||||
|     } | ||||
|   } else { | ||||
|     if (x == 0x80) { | ||||
|       return fNeg0; | ||||
|     } | ||||
|     if (exponent == ((1 << we) - 1)) { | ||||
|       return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; | ||||
|     } | ||||
|   } | ||||
|   typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; | ||||
|   if (we == 5 && is_half && !negative_zero_nan) { | ||||
|     retval = x << 8; | ||||
|     return reinterpret_cast<const T&>(retval); | ||||
|   } | ||||
|  | ||||
|   const int exp_low_cutoff = | ||||
|       (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); | ||||
|  | ||||
|   // subnormal input | ||||
|   if (exponent == 0) { | ||||
|     // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above | ||||
|     int sh = 1 + clz(mantissa) - (32 - wm); | ||||
|     mantissa <<= sh; | ||||
|     exponent += 1 - sh; | ||||
|     mantissa &= ((1 << wm) - 1); | ||||
|   } | ||||
|   exponent += exp_low_cutoff - 1; | ||||
|   mantissa <<= wmo - wm; | ||||
|  | ||||
|   // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) | ||||
|   if (exponent <= 0) { | ||||
|     mantissa |= 1 << wmo; | ||||
|     mantissa >>= 1 - exponent; | ||||
|     exponent = 0; | ||||
|   } | ||||
|  | ||||
|   if (sizeof(T) == 2) { | ||||
|     retval = (sign << 15) | (exponent << 10) | mantissa; | ||||
|   } else { | ||||
|     retval = (sign << 31) | (exponent << 23) | mantissa; | ||||
|   } | ||||
|   return reinterpret_cast<const T&>(retval); | ||||
| } | ||||
|  | ||||
| }  // namespace hip_fp8_impl | ||||
							
								
								
									
										575
									
								
								csrc/quantization/fp8/amd/quant_utils.cuh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										575
									
								
								csrc/quantization/fp8/amd/quant_utils.cuh
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,575 @@ | ||||
| #pragma once | ||||
| #include "hip_float8.h" | ||||
|  | ||||
| #include <hip/hip_fp16.h> | ||||
| #include <hip/hip_bf16.h> | ||||
| #include <hip/hip_bfloat16.h> | ||||
|  | ||||
| #include "../../../attention/dtype_fp8.cuh" | ||||
| #include "../../../attention/dtype_float32.cuh" | ||||
| #include "../../../attention/dtype_bfloat16.cuh" | ||||
|  | ||||
| namespace vllm { | ||||
| #ifdef USE_ROCM | ||||
|  | ||||
| namespace fp8 { | ||||
|   #ifdef ENABLE_FP8 | ||||
|  | ||||
| template <typename Tout, typename Tin> | ||||
| __inline__ __device__ Tout vec_conversion(const Tin& x) { | ||||
|   return x; | ||||
| } | ||||
|  | ||||
| template <typename Tout, typename Tin> | ||||
| __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, | ||||
|                                                  const float scale) { | ||||
|   return x; | ||||
| } | ||||
|  | ||||
| // fp8 -> half | ||||
| template <> | ||||
| __inline__ __device__ uint16_t | ||||
| vec_conversion<uint16_t, uint8_t>(const uint8_t& a) { | ||||
|   hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|   __half_raw res; | ||||
|   res.data = static_cast<float>(f8); | ||||
|   return res.x; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t | ||||
| vec_conversion<uint32_t, uint16_t>(const uint16_t& a) { | ||||
|     #if defined(__HIP__MI300__) && \ | ||||
|         defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|   const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|   union { | ||||
|     __half2_raw h2r; | ||||
|     uint32_t ui32; | ||||
|   } tmp; | ||||
|   tmp.h2r.x.data = f2[0]; | ||||
|   tmp.h2r.y.data = f2[1]; | ||||
|   return tmp.ui32; | ||||
|     #else | ||||
|   union { | ||||
|     uint16_t u16[2]; | ||||
|     uint32_t u32; | ||||
|   } tmp; | ||||
|  | ||||
|   tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a)); | ||||
|   tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U)); | ||||
|   return tmp.u32; | ||||
|     #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) { | ||||
|   union { | ||||
|     uint2 u32x2; | ||||
|     uint32_t u32[2]; | ||||
|   } tmp; | ||||
|   tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a); | ||||
|   tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U)); | ||||
|   return tmp.u32x2; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) { | ||||
|   union { | ||||
|     uint4 u64x2; | ||||
|     uint2 u64[2]; | ||||
|   } tmp; | ||||
|   tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x); | ||||
|   tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y); | ||||
|   return tmp.u64x2; | ||||
| } | ||||
|  | ||||
| using __nv_bfloat16 = __hip_bfloat16; | ||||
|  | ||||
| // fp8 -> __nv_bfloat16 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat16 | ||||
| vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { | ||||
|   hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|   float f{f8}; | ||||
|   return __float2bfloat16(f); | ||||
| } | ||||
|  | ||||
| using __nv_bfloat162 = __hip_bfloat162; | ||||
|  | ||||
| // fp8x2 -> __nv_bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 | ||||
| vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { | ||||
|   __nv_bfloat162 res; | ||||
|   res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); | ||||
|   res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> bf16_4_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t | ||||
| vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) { | ||||
|   bf16_4_t res; | ||||
|   res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); | ||||
|   res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> bf16_8_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) { | ||||
|   bf16_4_t tmp1, tmp2; | ||||
|   tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); | ||||
|   tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); | ||||
|   bf16_8_t res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8 -> float | ||||
| template <> | ||||
| __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) { | ||||
|   hip_fp8 fp8{a, hip_fp8::from_bits()}; | ||||
|   return static_cast<float>(fp8); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> float2 | ||||
| template <> | ||||
| __inline__ __device__ float2 | ||||
| vec_conversion<float2, uint16_t>(const uint16_t& a) { | ||||
|     #if defined(__HIP__MI300__) && \ | ||||
|         defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|   float2 res; | ||||
|   const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|   res.x = f2[0]; | ||||
|   res.y = f2[1]; | ||||
|   return res; | ||||
|     #else | ||||
|   float2 res; | ||||
|   res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a)); | ||||
|   res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U)); | ||||
|   return res; | ||||
|     #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ Float4_ | ||||
| vec_conversion<Float4_, uint32_t>(const uint32_t& a) { | ||||
|   Float4_ res; | ||||
|   res.x = vec_conversion<float2, uint16_t>((uint16_t)a); | ||||
|   res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> float8 | ||||
| template <> | ||||
| __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) { | ||||
|   Float4_ tmp1, tmp2; | ||||
|   tmp1 = vec_conversion<Float4_, uint32_t>(a.x); | ||||
|   tmp2 = vec_conversion<Float4_, uint32_t>(a.y); | ||||
|   Float8_ res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // half -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t | ||||
| vec_conversion<uint8_t, uint16_t>(const uint16_t& a) { | ||||
|   __half_raw tmp; | ||||
|   tmp.x = a; | ||||
|  | ||||
|   hip_fp8 f8{static_cast<float>(tmp.data)}; | ||||
|   return f8.data; | ||||
| } | ||||
|  | ||||
| // bf16 -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t | ||||
| vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) { | ||||
|   hip_fp8 res{__bfloat162float(a)}; | ||||
|   return res.data; | ||||
| } | ||||
|  | ||||
| // float -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) { | ||||
|   hip_fp8 f8(a); | ||||
|   return f8.data; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 | ||||
| vec_conversion<float4, uint32_t>(const uint32_t& a) { | ||||
|   Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); | ||||
|   float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // float2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t | ||||
| vec_conversion<uint32_t, float2>(const float2& a) { | ||||
|   union { | ||||
|     half2 float16; | ||||
|     uint32_t uint32; | ||||
|   }; | ||||
|  | ||||
|   float16 = __float22half2_rn(a); | ||||
|   return uint32; | ||||
| } | ||||
|  | ||||
| // Float4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) { | ||||
|   uint2 b; | ||||
|   float2 val; | ||||
|   val.x = a.x.x; | ||||
|   val.y = a.x.y; | ||||
|   b.x = vec_conversion<uint32_t, float2>(val); | ||||
|  | ||||
|   val.x = a.y.x; | ||||
|   val.y = a.y.y; | ||||
|   b.y = vec_conversion<uint32_t, float2>(val); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| // Float4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) { | ||||
|   float4 b; | ||||
|   b.x = a.x.x; | ||||
|   b.y = a.x.y; | ||||
|   b.z = a.y.x; | ||||
|   b.w = a.y.y; | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| // Float8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) { | ||||
|   uint4 b; | ||||
|   b.x = vec_conversion<uint32_t, float2>(a.x); | ||||
|   b.y = vec_conversion<uint32_t, float2>(a.y); | ||||
|   b.z = vec_conversion<uint32_t, float2>(a.z); | ||||
|   b.w = vec_conversion<uint32_t, float2>(a.w); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| // float2 -> bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 | ||||
| vec_conversion<__nv_bfloat162, float2>(const float2& a) { | ||||
|   __nv_bfloat162 b = __float22bfloat162_rn(a); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| // Float4 -> bfloat162x2 | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t | ||||
| vec_conversion<bf16_4_t, Float4_>(const Float4_& a) { | ||||
|   bf16_4_t b; | ||||
|   b.x = __float22bfloat162_rn(a.x); | ||||
|   b.y = __float22bfloat162_rn(a.y); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| // Float8 -> bfloat162x4 | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t | ||||
| vec_conversion<bf16_8_t, Float8_>(const Float8_& a) { | ||||
|   bf16_8_t b; | ||||
|   b.x = __float22bfloat162_rn(a.x); | ||||
|   b.y = __float22bfloat162_rn(a.y); | ||||
|   b.z = __float22bfloat162_rn(a.z); | ||||
|   b.w = __float22bfloat162_rn(a.w); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| /* Scaled and vectorized conversions, for data exchange between high and low | ||||
|    precision domains | ||||
|  | ||||
|    Convention of the scale in API, e.g: FP8_data = Quantization( | ||||
|    High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * | ||||
|    scale =>  HP | ||||
|  | ||||
|  */ | ||||
|  | ||||
| // fp8 -> half | ||||
| template <> | ||||
| __inline__ __device__ uint16_t | ||||
| scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) { | ||||
|   hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|   __half_raw res; | ||||
|   res.data = static_cast<float>(f8) * scale; | ||||
|   return res.x; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>( | ||||
|     const uint16_t& a, const float scale) { | ||||
|     #if defined(__HIP__MI300__) && \ | ||||
|         defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|   const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|   union { | ||||
|     __half2_raw h2r; | ||||
|     uint32_t ui32; | ||||
|   } tmp; | ||||
|   tmp.h2r.x.data = f2[0] * scale; | ||||
|   tmp.h2r.y.data = f2[1] * scale; | ||||
|   return tmp.ui32; | ||||
|     #else | ||||
|   union { | ||||
|     uint16_t u16[2]; | ||||
|     uint32_t u32; | ||||
|   } tmp; | ||||
|  | ||||
|   tmp.u16[0] = | ||||
|       scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale); | ||||
|   tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>( | ||||
|       static_cast<uint8_t>(a >> 8U), scale); | ||||
|   return tmp.u32; | ||||
|     #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 | ||||
| scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) { | ||||
|   union { | ||||
|     uint2 u32x2; | ||||
|     uint32_t u32[2]; | ||||
|   } tmp; | ||||
|   tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); | ||||
|   tmp.u32[1] = | ||||
|       scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale); | ||||
|   return tmp.u32x2; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 | ||||
| scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) { | ||||
|   union { | ||||
|     uint4 u64x2; | ||||
|     uint2 u64[2]; | ||||
|   } tmp; | ||||
|   tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale); | ||||
|   tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale); | ||||
|   return tmp.u64x2; | ||||
| } | ||||
|  | ||||
| using __nv_bfloat16 = __hip_bfloat16; | ||||
|  | ||||
| // fp8 -> __nv_bfloat16 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat16 | ||||
| scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, | ||||
|                                               const float scale) { | ||||
|   hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|   float f{f8}; | ||||
|   return __float2bfloat16(f * scale); | ||||
| } | ||||
|  | ||||
| using __nv_bfloat162 = __hip_bfloat162; | ||||
|  | ||||
| // fp8x2 -> __nv_bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 | ||||
| scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, | ||||
|                                                 const float scale) { | ||||
|   __nv_bfloat162 res; | ||||
|   res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); | ||||
|   res.y = | ||||
|       scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> bf16_4_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>( | ||||
|     const uint32_t& a, const float scale) { | ||||
|   bf16_4_t res; | ||||
|   res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); | ||||
|   res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), | ||||
|                                                           scale); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> bf16_8_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t | ||||
| scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) { | ||||
|   bf16_4_t tmp1, tmp2; | ||||
|   tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale); | ||||
|   tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale); | ||||
|   bf16_8_t res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8 -> float | ||||
| template <> | ||||
| __inline__ __device__ float scaled_vec_conversion<float, uint8_t>( | ||||
|     const uint8_t& a, const float scale) { | ||||
|   hip_fp8 fp8{a, hip_fp8::from_bits()}; | ||||
|   return static_cast<float>(fp8) * scale; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> float2 | ||||
| template <> | ||||
| __inline__ __device__ float2 | ||||
| scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) { | ||||
|     #if defined(__HIP__MI300__) && \ | ||||
|         defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|   float2 res; | ||||
|   const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|   res.x = f2[0] * scale; | ||||
|   res.y = f2[1] * scale; | ||||
|   return res; | ||||
|     #else | ||||
|   float2 res; | ||||
|   res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale); | ||||
|   res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), | ||||
|                                                 scale); | ||||
|   return res; | ||||
|     #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ Float4_ | ||||
| scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) { | ||||
|   Float4_ res; | ||||
|   res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale); | ||||
|   res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> float8 | ||||
| template <> | ||||
| __inline__ __device__ Float8_ | ||||
| scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) { | ||||
|   Float4_ tmp1, tmp2; | ||||
|   tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale); | ||||
|   tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale); | ||||
|   Float8_ res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| /* Quantize(HP / scale) => FP8 */ | ||||
|  | ||||
| // TODO(Hai): vectorized to add | ||||
|  | ||||
| // half -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t | ||||
| scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) { | ||||
|   __half_raw tmp; | ||||
|   tmp.x = a; | ||||
|  | ||||
|   hip_fp8 f8{static_cast<float>(tmp.data) / scale}; | ||||
|   return f8.data; | ||||
| } | ||||
|  | ||||
| // bf16 -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( | ||||
|     const __nv_bfloat16& a, const float scale) { | ||||
|   hip_fp8 res{__bfloat162float(a) / scale}; | ||||
|   return res.data; | ||||
| } | ||||
|  | ||||
| // float -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t | ||||
| scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) { | ||||
|   hip_fp8 f8(a / scale); | ||||
|   return f8.data; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 | ||||
| scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) { | ||||
|   Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale); | ||||
|   float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||
|   return res; | ||||
| } | ||||
|   #endif  // ENABLE_FP8 | ||||
|  | ||||
| template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> | ||||
| __inline__ __device__ Tout convert(const Tin& x) { | ||||
|   #ifdef ENABLE_FP8 | ||||
|   if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { | ||||
|     return vec_conversion<Tout, Tin>(x); | ||||
|   } | ||||
|   #endif | ||||
|   assert(false); | ||||
| } | ||||
|  | ||||
| template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> | ||||
| __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { | ||||
|   #ifdef ENABLE_FP8 | ||||
|   if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { | ||||
|     return scaled_vec_conversion<Tout, Tin>(x, scale); | ||||
|   } | ||||
|   #endif | ||||
|   assert(false); | ||||
| } | ||||
|  | ||||
|   // The following macro is used to dispatch the conversion function based on | ||||
|   // the data type of the key and value cache. The FN is a macro that calls a | ||||
|   // function with template<typename scalar_t, typename cache_t, | ||||
|   // Fp8KVCacheDataType kv_dt>. | ||||
|   #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \ | ||||
|     if (KV_DTYPE == "auto") {                                                  \ | ||||
|       if (SRC_DTYPE == at::ScalarType::Float) {                                \ | ||||
|         FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \ | ||||
|       } else if (SRC_DTYPE == at::ScalarType::Half) {                          \ | ||||
|         FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \ | ||||
|       } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \ | ||||
|         FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \ | ||||
|       } else {                                                                 \ | ||||
|         TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ | ||||
|       }                                                                        \ | ||||
|     } else {                                                                   \ | ||||
|       if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                       \ | ||||
|         if (SRC_DTYPE == at::ScalarType::Float) {                              \ | ||||
|           FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);              \ | ||||
|         } else if (SRC_DTYPE == at::ScalarType::Half) {                        \ | ||||
|           FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);           \ | ||||
|         } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \ | ||||
|           FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);      \ | ||||
|         } else {                                                               \ | ||||
|           TORCH_CHECK(false,                                                   \ | ||||
|                       "Unsupported input type of kv cache: ", SRC_DTYPE);      \ | ||||
|         }                                                                      \ | ||||
|       } else {                                                                 \ | ||||
|         TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \ | ||||
|       }                                                                        \ | ||||
|     } | ||||
|  | ||||
| }  // namespace fp8 | ||||
| #endif  // USE_ROCM | ||||
| }  // namespace vllm | ||||
| @ -1,167 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #ifdef __HIPCC__ | ||||
| #include <hip/hip_runtime.h> | ||||
| #else | ||||
| #include <type_traits> | ||||
| #include <stdint.h> | ||||
| #include <math.h> | ||||
| #include <iostream> | ||||
| #endif | ||||
|  | ||||
| #include "hip_float8_impl.h" | ||||
|  | ||||
| struct alignas(1) hip_fp8 | ||||
| { | ||||
|     struct from_bits_t | ||||
|     { | ||||
|     }; | ||||
|     HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } | ||||
|     uint8_t data; | ||||
|  | ||||
|     hip_fp8() = default; | ||||
|     HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; | ||||
|     HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; | ||||
|     explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) | ||||
|         : data(v) | ||||
|     { | ||||
|     } | ||||
|  | ||||
| #ifdef __HIP__MI300__ | ||||
|     // NOTE: ON-DEVICE... always optimal bias | ||||
|     explicit HIP_FP8_DEVICE hip_fp8(float v) | ||||
|         : data(hip_fp8_impl::to_fp8_from_fp32(v)) | ||||
|     { | ||||
|     } | ||||
|  | ||||
|     explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) | ||||
|         : hip_fp8(static_cast<float>(v)) | ||||
|     { | ||||
|     } | ||||
|  | ||||
|     // Host only implementation using s/w simulation | ||||
|     explicit HIP_FP8_HOST | ||||
| #else  // __HIP__MI300__ | ||||
|     // both Host and DEVICE for non-MI300 using s/w simulation | ||||
|     explicit HIP_FP8_HOST_DEVICE | ||||
| #endif // __HIP__MI300__ | ||||
|     hip_fp8(float v) | ||||
|     { | ||||
|         data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); | ||||
|     } | ||||
|  | ||||
|     explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) | ||||
|         : hip_fp8(static_cast<float>(v)) | ||||
|     { | ||||
|     } | ||||
|  | ||||
| #ifdef __HIP__MI300__ | ||||
|     // upcast using device specific intrinsic | ||||
|     explicit inline HIP_FP8_DEVICE operator float() const | ||||
|     { | ||||
|         float fval; | ||||
|         uint32_t i32val = static_cast<uint32_t>(data); | ||||
|  | ||||
|         // upcast | ||||
|         asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); | ||||
|  | ||||
|         return fval; | ||||
|     } | ||||
|  | ||||
|     explicit inline HIP_FP8_HOST operator float() const | ||||
| #else  // __HIP__MI300__ | ||||
|     explicit inline HIP_FP8_HOST_DEVICE operator float() const | ||||
| #endif // __HIP__MI300__ | ||||
|     { | ||||
|         return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); | ||||
|     } | ||||
| }; | ||||
|  | ||||
| namespace std | ||||
| { | ||||
| inline hip_fp8 sin(hip_fp8 a) | ||||
| { | ||||
|     return hip_fp8(sinf(float(a))); | ||||
| } | ||||
| inline hip_fp8 cos(hip_fp8 a) | ||||
| { | ||||
|     return hip_fp8(cosf(float(a))); | ||||
| } | ||||
| HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) | ||||
| { | ||||
|     return a; | ||||
| } | ||||
| } // namespace std | ||||
|  | ||||
| // Special operator overloading | ||||
| inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) | ||||
| { | ||||
|     return os << float(f8); | ||||
| } | ||||
|  | ||||
| // all + operator overloading with mixed types | ||||
| // mixed types, always converts to f32, does computation in f32, and returns float | ||||
| inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) | ||||
| { | ||||
|     return (fa + float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) | ||||
| { | ||||
|     return (float(a) + fb); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) | ||||
| { | ||||
|     return hip_fp8(float(a) + float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) | ||||
| { | ||||
|     return a = hip_fp8(float(a) + float(b)); | ||||
| } | ||||
|  | ||||
| // overloading multiplication, always returns float, | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) | ||||
| { | ||||
|     return float(a) * float(b); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) | ||||
| { | ||||
|     return (a * float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) | ||||
| { | ||||
|     return (float(a) * b); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) | ||||
| { | ||||
|     return ((float)a * float(b)); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) | ||||
| { | ||||
|     return ((float)a * float(b)); | ||||
| } | ||||
|  | ||||
| // overloading for compare | ||||
| inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) | ||||
| { | ||||
|     return (a.data == b.data); | ||||
| } | ||||
| inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) | ||||
| { | ||||
|     return (a.data != b.data); | ||||
| } | ||||
|  | ||||
| inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) | ||||
| { | ||||
|     return static_cast<float>(a) >= static_cast<float>(b); | ||||
| } | ||||
| inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) | ||||
| { | ||||
|     return static_cast<float>(a) > static_cast<float>(b); | ||||
| } | ||||
| @ -1,316 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) | ||||
| #define __HIP__MI300__ | ||||
| #endif | ||||
|  | ||||
| #ifdef __HIPCC__ | ||||
| #define HIP_FP8_HOST_DEVICE __host__ __device__ | ||||
| #define HIP_FP8_HOST __host__ | ||||
| #define HIP_FP8_DEVICE __device__ | ||||
| #else | ||||
| #define HIP_FP8_HOST_DEVICE | ||||
| #define HIP_FP8_HOST | ||||
| #define HIP_FP8_DEVICE | ||||
| #endif | ||||
|  | ||||
| namespace hip_fp8_impl | ||||
| { | ||||
|  | ||||
| #ifdef __HIP__MI300__ | ||||
| HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) | ||||
| { | ||||
|     uint8_t i8data; | ||||
|     union { | ||||
|         float fval; | ||||
|         uint32_t i32val; | ||||
|         uint8_t i8val[4]; // NOTE: not endian independent | ||||
|     } val; | ||||
|  | ||||
|     uint32_t ival = 0; | ||||
|     val.fval = v; | ||||
|  | ||||
|     if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping | ||||
|         val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); | ||||
|     } | ||||
|  | ||||
|     ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, | ||||
|         false); // false -> WORD0 | ||||
|     val.i32val = ival; | ||||
|     i8data = val.i8val[0]; | ||||
|  | ||||
|     return i8data; | ||||
| } | ||||
| #endif // __HIP__MI300__ | ||||
|  | ||||
| HIP_FP8_HOST inline int clz(uint32_t x) | ||||
| { | ||||
|     return __builtin_clz(x); | ||||
| } | ||||
| #if defined(__HIPCC__) || defined(__CUDA_ARCH__) | ||||
| HIP_FP8_DEVICE inline int clz(uint32_t x) | ||||
| { | ||||
|     return __clz(x); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| template <int we, int wm, typename T, bool negative_zero_nan, bool clip> | ||||
| HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) | ||||
| { | ||||
| #ifdef __HIPCC__ | ||||
|     constexpr bool is_half = std::is_same<T, _Float16>::value; | ||||
| #else | ||||
|     constexpr bool is_half = false; | ||||
| #endif | ||||
|     constexpr bool is_float = std::is_same<T, float>::value; | ||||
|     static_assert(wm + we == 7, "wm+we==7"); | ||||
|     static_assert(is_half || is_float, "Only half and float can be cast to f8"); | ||||
|  | ||||
|     const int mfmt = (sizeof(T) == 4) ? 23 : 10; | ||||
|     uint32_t x; | ||||
|     if (sizeof(T) == 4) { | ||||
|         x = reinterpret_cast<uint32_t&>(_x); | ||||
|     } else { | ||||
|         x = reinterpret_cast<uint16_t&>(_x); | ||||
|     } | ||||
|  | ||||
|     uint32_t head, mantissa; | ||||
|     int exponent, bias; | ||||
|     uint32_t sign; | ||||
|  | ||||
|     if (sizeof(T) == 4) { | ||||
|         head = x & 0xFF800000; | ||||
|         mantissa = x & 0x7FFFFF; | ||||
|         exponent = (head >> 23) & 0xFF; | ||||
|         sign = head >> 31; | ||||
|         bias = 127; | ||||
|     } else { | ||||
|         head = x & 0xFC00; | ||||
|         mantissa = x & 0x3FF; | ||||
|         exponent = (head >> 10) & 0x1F; | ||||
|         sign = head >> 15; | ||||
|         bias = 15; | ||||
|     } | ||||
|  | ||||
|     uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); | ||||
|  | ||||
|     // Deal with inf and NaNs | ||||
|     if (negative_zero_nan) { | ||||
|         if (sizeof(T) == 4) { | ||||
|             if ((x & 0x7F800000) == 0x7F800000) { | ||||
|                 return 0x80; | ||||
|             } | ||||
|         } else { | ||||
|             // if(__hisinf(x) || __hisnan(x)) | ||||
|             if ((x & 0x7C00) == 0x7C00) { | ||||
|                 return 0x80; | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         if (sizeof(T) == 4) { | ||||
|             if ((x & 0x7F800000) == 0x7F800000) { | ||||
|                 return signed_inf + (mantissa != 0 ? 1 : 0); | ||||
|             } | ||||
|         } else { | ||||
|             if ((x & 0x7C00) == 0x7C00) { | ||||
|                 return signed_inf + (mantissa != 0 ? 1 : 0); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     if (x == 0) { | ||||
|         return 0; | ||||
|     } | ||||
|  | ||||
|     // First need to check if it is normal or denorm as there is a difference of | ||||
|     // implicit 1 Then need to adjust the exponent to align with the F8 exponent, | ||||
|     // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng | ||||
|     // to mantissa and truncate. And for RNE, no need to add rng. Then probably | ||||
|     // need to check whether there is carry and adjust exponent and mantissa again | ||||
|  | ||||
|     // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent | ||||
|     // bits | ||||
|     const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); | ||||
|     const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal | ||||
|     // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) | ||||
|     // f8_exponent is the converted f8 exponent with bias encoding | ||||
|     // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, | ||||
|     // the difference needs to be adjusted and mantissa shifted | ||||
|     int act_exponent, f8_exponent, exponent_diff; | ||||
|  | ||||
|     if (exponent == 0) { // fp32/fp16 is in denormal. | ||||
|         /* fp32 denormal is below 2^-127 so it is usually not a concern here, we | ||||
| mostly concern fp16 here. In this case, f8 is usually in denormal. But there | ||||
| could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has | ||||
| exponent bias 16. It means that there are some numbers in fp16 denormal but they | ||||
| are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers | ||||
| where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 | ||||
| (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1  */ | ||||
|         act_exponent = exponent - bias + 1; | ||||
|         exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal | ||||
|     } else {                                                     // fp32/fp16 is normal with implicit 1 | ||||
|         act_exponent = exponent - bias; | ||||
|         if (act_exponent <= f8_denormal_act_exponent) { | ||||
|             /* This is the case where fp32/fp16 is normal but it is in f8 denormal | ||||
|  range. For example fp8 nanoo mode, denormal exponent is -7, but if the | ||||
|  fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, | ||||
|  Therefore it needs to be adjust to -6 and mantissa shift right by 1. | ||||
|  So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ | ||||
|             exponent_diff = f8_denormal_act_exponent - act_exponent; | ||||
|         } else {               // both fp32/fp16 and f8 are in normal range | ||||
|             exponent_diff = 0; // exponent_diff=0 does not mean there is no difference | ||||
|                                // for this case, | ||||
|                                // act_exponent could be larger. Just that it does not need shift mantissa | ||||
|         } | ||||
|         mantissa += (1 << mfmt); // Add the implicit 1 into mantissa | ||||
|     } | ||||
|  | ||||
|     bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == | ||||
|                     static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1)); | ||||
|     /* This part is a bit tricky. The judgment of whether it is a tie needs to be | ||||
|    done before we shift right as shift right could rip off some residual part | ||||
|    and make something not midpoint look like midpoint. For example, the fp16 | ||||
|    number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after | ||||
|    shift right by 4 bits, it would look like midpoint. | ||||
| */ | ||||
|  | ||||
|     if (exponent_diff > 0) { | ||||
|         mantissa >>= exponent_diff; | ||||
|     } else if (exponent_diff == -1) { | ||||
|         mantissa <<= -exponent_diff; | ||||
|     } | ||||
|     bool implicit_one = mantissa & (1 << mfmt); | ||||
|     // if there is no implicit 1, it  means the f8 is denormal and need to adjust | ||||
|     // to denorm exponent | ||||
|     f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); | ||||
|  | ||||
|     // Now we have the exponent and mantissa adjusted | ||||
|     uint32_t drop_mask = (1 << (mfmt - wm)) - 1; | ||||
|     bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that | ||||
|                                               // is not truncated is 1 | ||||
|     mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; | ||||
|  | ||||
|     // Now we deal with overflow | ||||
|     if (f8_exponent == 0) { | ||||
|         if ((1 << mfmt) & mantissa) { | ||||
|             f8_exponent = 1; // denormal overflow to become normal, promote exponent | ||||
|         } | ||||
|     } else { | ||||
|         if ((1 << (mfmt + 1)) & mantissa) { | ||||
|             mantissa >>= 1; | ||||
|             f8_exponent++; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     mantissa >>= (mfmt - wm); | ||||
|  | ||||
|     // above range: quantize to maximum possible float of the same sign | ||||
|     const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); | ||||
|     if (f8_exponent > max_exp) { | ||||
|         if (clip) { | ||||
|             mantissa = (1 << wm) - 1; | ||||
|             f8_exponent = max_exp; | ||||
|         } else { | ||||
|             return signed_inf; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (f8_exponent == 0 && mantissa == 0) { | ||||
|         return negative_zero_nan ? 0 : (sign << 7); | ||||
|     } | ||||
|     mantissa &= (1 << wm) - 1; | ||||
|     return (sign << 7) | (f8_exponent << wm) | mantissa; | ||||
| } | ||||
|  | ||||
| template <int we, int wm, typename T = float, bool negative_zero_nan = true> | ||||
| inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) | ||||
| { | ||||
| #ifdef __HIPCC__ | ||||
|     constexpr bool is_half = std::is_same<T, _Float16>::value; | ||||
| #else | ||||
|     constexpr bool is_half = false; | ||||
| #endif | ||||
|     constexpr bool is_float = std::is_same<T, float>::value; | ||||
|     static_assert(is_half || is_float, "only half and float are supported"); | ||||
|  | ||||
|     constexpr int weo = is_half ? 5 : 8; | ||||
|     constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); | ||||
|  | ||||
|     T fInf, fNegInf, fNaN, fNeg0; | ||||
|  | ||||
| #ifdef __HIPCC__ | ||||
|     if (is_half) { | ||||
|         const uint16_t ihInf = 0x7C00; | ||||
|         const uint16_t ihNegInf = 0xFC00; | ||||
|         const uint16_t ihNaN = 0x7C01; | ||||
|         const uint16_t ihNeg0 = 0x8000; | ||||
|         fInf = reinterpret_cast<const _Float16&>(ihInf); | ||||
|         fNegInf = reinterpret_cast<const _Float16&>(ihNegInf); | ||||
|         fNaN = reinterpret_cast<const _Float16&>(ihNaN); | ||||
|         fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0); | ||||
|     } else | ||||
| #endif | ||||
|         if (is_float) { | ||||
|         const uint32_t ifInf = 0x7F800000; | ||||
|         const uint32_t ifNegInf = 0xFF800000; | ||||
|         const uint32_t ifNaN = 0x7F800001; | ||||
|         const uint32_t ifNeg0 = 0x80000000; | ||||
|         fInf = reinterpret_cast<const float&>(ifInf); | ||||
|         fNegInf = reinterpret_cast<const float&>(ifNegInf); | ||||
|         fNaN = reinterpret_cast<const float&>(ifNaN); | ||||
|         fNeg0 = reinterpret_cast<const float&>(ifNeg0); | ||||
|     } | ||||
|  | ||||
|     if (x == 0) { | ||||
|         return 0; | ||||
|     } | ||||
|  | ||||
|     uint32_t sign = x >> 7; | ||||
|     uint32_t mantissa = x & ((1 << wm) - 1); | ||||
|     int exponent = (x & 0x7F) >> wm; | ||||
|     if (negative_zero_nan) { | ||||
|         if (x == 0x80) { | ||||
|             return fNaN; | ||||
|         } | ||||
|     } else { | ||||
|         if (x == 0x80) { | ||||
|             return fNeg0; | ||||
|         } | ||||
|         if (exponent == ((1 << we) - 1)) { | ||||
|             return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; | ||||
|         } | ||||
|     } | ||||
|     typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval; | ||||
|     if (we == 5 && is_half && !negative_zero_nan) { | ||||
|         retval = x << 8; | ||||
|         return reinterpret_cast<const T&>(retval); | ||||
|     } | ||||
|  | ||||
|     const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); | ||||
|  | ||||
|     // subnormal input | ||||
|     if (exponent == 0) { | ||||
|         // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above | ||||
|         int sh = 1 + clz(mantissa) - (32 - wm); | ||||
|         mantissa <<= sh; | ||||
|         exponent += 1 - sh; | ||||
|         mantissa &= ((1 << wm) - 1); | ||||
|     } | ||||
|     exponent += exp_low_cutoff - 1; | ||||
|     mantissa <<= wmo - wm; | ||||
|  | ||||
|     // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) | ||||
|     if (exponent <= 0) { | ||||
|         mantissa |= 1 << wmo; | ||||
|         mantissa >>= 1 - exponent; | ||||
|         exponent = 0; | ||||
|     } | ||||
|  | ||||
|     if (sizeof(T) == 2) { | ||||
|         retval = (sign << 15) | (exponent << 10) | mantissa; | ||||
|     } else { | ||||
|         retval = (sign << 31) | (exponent << 23) | mantissa; | ||||
|     } | ||||
|     return reinterpret_cast<const T&>(retval); | ||||
| } | ||||
|  | ||||
| } // namespace hip_fp8_impl | ||||
| @ -1,517 +0,0 @@ | ||||
| #pragma once | ||||
| #include "hip_float8.h" | ||||
|  | ||||
| #include <hip/hip_fp16.h> | ||||
| #include <hip/hip_bf16.h> | ||||
| #include <hip/hip_bfloat16.h> | ||||
|  | ||||
| #include "../../../attention/dtype_float32.cuh" | ||||
| #include "../../../attention/dtype_bfloat16.cuh" | ||||
|  | ||||
| namespace vllm | ||||
| { | ||||
| namespace fp8_e4m3 { | ||||
| template <typename Tout, typename Tin> | ||||
| __inline__ __device__ Tout vec_conversion(const Tin& x) | ||||
| { | ||||
|     return x; | ||||
| } | ||||
|  | ||||
| template <typename Tout, typename Tin> | ||||
| __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) | ||||
| { | ||||
|     return x; | ||||
| } | ||||
|  | ||||
| // fp8 -> half | ||||
| template <> | ||||
| __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a) | ||||
| { | ||||
|     hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|     __half_raw res; | ||||
|     res.data = static_cast<float>(f8); | ||||
|     return res.x; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a) | ||||
| { | ||||
| #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|     const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|     union { | ||||
|         __half2_raw h2r; | ||||
|         uint32_t ui32; | ||||
|     } tmp; | ||||
|     tmp.h2r.x.data = f2[0]; | ||||
|     tmp.h2r.y.data = f2[1]; | ||||
|     return tmp.ui32; | ||||
| #else | ||||
|     union { | ||||
|         uint16_t u16[2]; | ||||
|         uint32_t u32; | ||||
|     } tmp; | ||||
|  | ||||
|     tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a)); | ||||
|     tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U)); | ||||
|     return tmp.u32; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     union { | ||||
|         uint2 u32x2; | ||||
|         uint32_t u32[2]; | ||||
|     } tmp; | ||||
|     tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a); | ||||
|     tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U)); | ||||
|     return tmp.u32x2; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) | ||||
| { | ||||
|     union { | ||||
|         uint4 u64x2; | ||||
|         uint2 u64[2]; | ||||
|     } tmp; | ||||
|     tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x); | ||||
|     tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y); | ||||
|     return tmp.u64x2; | ||||
| } | ||||
|  | ||||
| using __nv_bfloat16 = __hip_bfloat16; | ||||
|  | ||||
| // fp8 -> __nv_bfloat16 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) | ||||
| { | ||||
|     hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|     float f{f8}; | ||||
|     return __float2bfloat16(f); | ||||
| } | ||||
|  | ||||
| using __nv_bfloat162 = __hip_bfloat162; | ||||
|  | ||||
| // fp8x2 -> __nv_bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) | ||||
| { | ||||
|     __nv_bfloat162 res; | ||||
|     res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); | ||||
|     res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> bf16_4_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     bf16_4_t res; | ||||
|     res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); | ||||
|     res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> bf16_8_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) | ||||
| { | ||||
|     bf16_4_t tmp1, tmp2; | ||||
|     tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); | ||||
|     tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); | ||||
|     bf16_8_t res; | ||||
|     res.x = tmp1.x; | ||||
|     res.y = tmp1.y; | ||||
|     res.z = tmp2.x; | ||||
|     res.w = tmp2.y; | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8 -> float | ||||
| template <> | ||||
| __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) | ||||
| { | ||||
|     hip_fp8 fp8{a, hip_fp8::from_bits()}; | ||||
|     return static_cast<float>(fp8); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> float2 | ||||
| template <> | ||||
| __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a) | ||||
| { | ||||
| #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|     float2 res; | ||||
|     const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|     res.x = f2[0]; | ||||
|     res.y = f2[1]; | ||||
|     return res; | ||||
| #else | ||||
|     float2 res; | ||||
|     res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a)); | ||||
|     res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U)); | ||||
|     return res; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     Float4_ res; | ||||
|     res.x = vec_conversion<float2, uint16_t>((uint16_t)a); | ||||
|     res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> float8 | ||||
| template <> | ||||
| __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) | ||||
| { | ||||
|     Float4_ tmp1, tmp2; | ||||
|     tmp1 = vec_conversion<Float4_, uint32_t>(a.x); | ||||
|     tmp2 = vec_conversion<Float4_, uint32_t>(a.y); | ||||
|     Float8_ res; | ||||
|     res.x = tmp1.x; | ||||
|     res.y = tmp1.y; | ||||
|     res.z = tmp2.x; | ||||
|     res.w = tmp2.y; | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // half -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a) | ||||
| { | ||||
|     __half_raw tmp; | ||||
|     tmp.x = a; | ||||
|  | ||||
|     hip_fp8 f8{static_cast<float>(tmp.data)}; | ||||
|     return f8.data; | ||||
| } | ||||
|  | ||||
| // bf16 -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) | ||||
| { | ||||
|     hip_fp8 res{__bfloat162float(a)}; | ||||
|     return res.data; | ||||
| } | ||||
|  | ||||
| // float -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) | ||||
| { | ||||
|     hip_fp8 f8(a); | ||||
|     return f8.data; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); | ||||
|     float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // float2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a) | ||||
| { | ||||
|     union { | ||||
|         half2 float16; | ||||
|         uint32_t uint32; | ||||
|     }; | ||||
|  | ||||
|     float16 = __float22half2_rn(a); | ||||
|     return uint32; | ||||
| } | ||||
|  | ||||
| // Float4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) | ||||
| { | ||||
|     uint2 b; | ||||
|     float2 val; | ||||
|     val.x = a.x.x; | ||||
|     val.y = a.x.y; | ||||
|     b.x = vec_conversion<uint32_t, float2>(val); | ||||
|  | ||||
|     val.x = a.y.x; | ||||
|     val.y = a.y.y; | ||||
|     b.y = vec_conversion<uint32_t, float2>(val); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| // Float4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) | ||||
| { | ||||
|     float4 b; | ||||
|     b.x = a.x.x; | ||||
|     b.y = a.x.y; | ||||
|     b.z = a.y.x; | ||||
|     b.w = a.y.y; | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| // Float8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) | ||||
| { | ||||
|     uint4 b; | ||||
|     b.x = vec_conversion<uint32_t, float2>(a.x); | ||||
|     b.y = vec_conversion<uint32_t, float2>(a.y); | ||||
|     b.z = vec_conversion<uint32_t, float2>(a.z); | ||||
|     b.w = vec_conversion<uint32_t, float2>(a.w); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| // float2 -> bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) | ||||
| { | ||||
|     __nv_bfloat162 b = __float22bfloat162_rn(a); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| // Float4 -> bfloat162x2 | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a) | ||||
| { | ||||
|     bf16_4_t b; | ||||
|     b.x = __float22bfloat162_rn(a.x); | ||||
|     b.y = __float22bfloat162_rn(a.y); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| // Float8 -> bfloat162x4 | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a) | ||||
| { | ||||
|     bf16_8_t b; | ||||
|     b.x = __float22bfloat162_rn(a.x); | ||||
|     b.y = __float22bfloat162_rn(a.y); | ||||
|     b.z = __float22bfloat162_rn(a.z); | ||||
|     b.w = __float22bfloat162_rn(a.w); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
|  | ||||
| /* Scaled and vectorized conversions, for data exchange between high and low precision domains | ||||
|  | ||||
|    Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) | ||||
|    s.t. | ||||
|      Quantize(HP / scale) => FP8 | ||||
|      Dequant(FP8) * scale =>  HP | ||||
|  | ||||
|  */ | ||||
|  | ||||
| // fp8 -> half | ||||
| template <> | ||||
| __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) | ||||
| { | ||||
|     hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|     __half_raw res; | ||||
|     res.data = static_cast<float>(f8) * scale; | ||||
|     return res.x; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale) | ||||
| { | ||||
| #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|     const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|     union { | ||||
|         __half2_raw h2r; | ||||
|         uint32_t ui32; | ||||
|     } tmp; | ||||
|     tmp.h2r.x.data = f2[0] * scale; | ||||
|     tmp.h2r.y.data = f2[1] * scale; | ||||
|     return tmp.ui32; | ||||
| #else | ||||
|     union { | ||||
|         uint16_t u16[2]; | ||||
|         uint32_t u32; | ||||
|     } tmp; | ||||
|  | ||||
|     tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale); | ||||
|     tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale); | ||||
|     return tmp.u32; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) | ||||
| { | ||||
|     union { | ||||
|         uint2 u32x2; | ||||
|         uint32_t u32[2]; | ||||
|     } tmp; | ||||
|     tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); | ||||
|     tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale); | ||||
|     return tmp.u32x2; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) | ||||
| { | ||||
|     union { | ||||
|         uint4 u64x2; | ||||
|         uint2 u64[2]; | ||||
|     } tmp; | ||||
|     tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale); | ||||
|     tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale); | ||||
|     return tmp.u64x2; | ||||
| } | ||||
|  | ||||
| using __nv_bfloat16 = __hip_bfloat16; | ||||
|  | ||||
| // fp8 -> __nv_bfloat16 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) | ||||
| { | ||||
|     hip_fp8 f8{a, hip_fp8::from_bits()}; | ||||
|     float f{f8}; | ||||
|     return __float2bfloat16(f * scale); | ||||
| } | ||||
|  | ||||
| using __nv_bfloat162 = __hip_bfloat162; | ||||
|  | ||||
| // fp8x2 -> __nv_bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) | ||||
| { | ||||
|     __nv_bfloat162 res; | ||||
|     res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); | ||||
|     res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> bf16_4_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale) | ||||
| { | ||||
|     bf16_4_t res; | ||||
|     res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); | ||||
|     res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> bf16_8_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) | ||||
| { | ||||
|     bf16_4_t tmp1, tmp2; | ||||
|     tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale); | ||||
|     tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale); | ||||
|     bf16_8_t res; | ||||
|     res.x = tmp1.x; | ||||
|     res.y = tmp1.y; | ||||
|     res.z = tmp2.x; | ||||
|     res.w = tmp2.y; | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8 -> float | ||||
| template <> | ||||
| __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale) | ||||
| { | ||||
|     hip_fp8 fp8{a, hip_fp8::from_bits()}; | ||||
|     return static_cast<float>(fp8) * scale; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> float2 | ||||
| template <> | ||||
| __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) | ||||
| { | ||||
| #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) | ||||
|     float2 res; | ||||
|     const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); | ||||
|     res.x = f2[0] * scale; | ||||
|     res.y = f2[1] * scale; | ||||
|     return res; | ||||
| #else | ||||
|     float2 res; | ||||
|     res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale); | ||||
|     res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale); | ||||
|     return res; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) | ||||
| { | ||||
|     Float4_ res; | ||||
|     res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale); | ||||
|     res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> float8 | ||||
| template <> | ||||
| __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) | ||||
| { | ||||
|     Float4_ tmp1, tmp2; | ||||
|     tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale); | ||||
|     tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale); | ||||
|     Float8_ res; | ||||
|     res.x = tmp1.x; | ||||
|     res.y = tmp1.y; | ||||
|     res.z = tmp2.x; | ||||
|     res.w = tmp2.y; | ||||
|     return res; | ||||
| } | ||||
|  | ||||
|  | ||||
| /* Quantize(HP / scale) => FP8 */ | ||||
|  | ||||
| // TODO(Hai): vectorized to add | ||||
|  | ||||
| // half -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) | ||||
| { | ||||
|     __half_raw tmp; | ||||
|     tmp.x = a; | ||||
|  | ||||
|     hip_fp8 f8{static_cast<float>(tmp.data)/scale}; | ||||
|     return f8.data; | ||||
| } | ||||
|  | ||||
| // bf16 -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale) | ||||
| { | ||||
|     hip_fp8 res{__bfloat162float(a)/scale}; | ||||
|     return res.data; | ||||
| } | ||||
|  | ||||
| // float -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) | ||||
| { | ||||
|     hip_fp8 f8(a/scale); | ||||
|     return f8.data; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) | ||||
| { | ||||
|     Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale); | ||||
|     float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| } | ||||
| } // namespace vllm | ||||
							
								
								
									
										124
									
								
								csrc/quantization/fp8/common.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								csrc/quantization/fp8/common.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,124 @@ | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
| #include <torch/extension.h> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
|  | ||||
| #include <cmath> | ||||
|  | ||||
| #include "cuda_compat.h" | ||||
| #include "dispatch_utils.h" | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { | ||||
|   float old; | ||||
|   old = (value >= 0) | ||||
|             ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) | ||||
|             : __uint_as_float( | ||||
|                   atomicMin((unsigned int*)addr, __float_as_uint(value))); | ||||
|  | ||||
|   return old; | ||||
| } | ||||
|  | ||||
| #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max() | ||||
|  | ||||
| template <typename scalar_t> | ||||
| __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion( | ||||
|     const scalar_t val, const float scale) { | ||||
|   float x = static_cast<float>(val) / scale; | ||||
|   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); | ||||
|   return static_cast<c10::Float8_e4m3fn>(r); | ||||
| } | ||||
|  | ||||
| // Compute the absolute maximum m of the input tensor and store | ||||
| // m / float8_e4m3::max() in *scale. Each thread block performs a | ||||
| // reduction tree and the memory in scale is atomically updated. | ||||
| // So to get the right answer, *scale needs to be initialized to | ||||
| // a value <= 0.0 and we need to wait for all thread blocks to | ||||
| // finish before consuming *scale. | ||||
| template <typename scalar_t> | ||||
| __global__ void segmented_max_reduction(float* __restrict__ scale, | ||||
|                                         const scalar_t* __restrict__ input, | ||||
|                                         int64_t num_elems) { | ||||
|   __shared__ float cache[1024]; | ||||
|   int i = blockDim.x * blockIdx.x + threadIdx.x; | ||||
|  | ||||
|   // First store maximum for all values processes by | ||||
|   // the current thread in cache[threadIdx.x] | ||||
|   scalar_t tmp = 0.0; | ||||
|   while (i < num_elems) { | ||||
|     float x = static_cast<float>(input[i]); | ||||
|     tmp = max(tmp, fabs(x)); | ||||
|     i += blockDim.x * gridDim.x; | ||||
|   } | ||||
|   cache[threadIdx.x] = tmp; | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // Now perform parallel reduction within the thread block | ||||
|   int ib = blockDim.x / 2; | ||||
|   while (ib != 0) { | ||||
|     if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { | ||||
|       cache[threadIdx.x] = cache[threadIdx.x + ib]; | ||||
|     } | ||||
|     __syncthreads(); | ||||
|     ib /= 2; | ||||
|   } | ||||
|   // Finally, since cache[0] contains the maximum for this thread block, | ||||
|   // atomically write the max to the target location | ||||
|   if (threadIdx.x == 0) { | ||||
|     atomicMaxFloat(scale, | ||||
|                    cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, | ||||
|                                         const scalar_t* __restrict__ input, | ||||
|                                         const float* __restrict__ scale, | ||||
|                                         int64_t num_elems) { | ||||
|   int i = blockDim.x * blockIdx.x + threadIdx.x; | ||||
|   while (i < num_elems) { | ||||
|     out[i] = scaled_fp8_conversion(input[i], *scale); | ||||
|     i += blockDim.x * gridDim.x; | ||||
|   } | ||||
| } | ||||
|  | ||||
| }  // namespace vllm | ||||
|  | ||||
| void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d] | ||||
|                              torch::Tensor& input,  // [..., d] | ||||
|                              torch::Tensor& scale)  // [1] | ||||
| { | ||||
|   int64_t num_tokens = input.numel() / input.size(-1); | ||||
|   int64_t num_elems = input.numel(); | ||||
|   dim3 grid(num_tokens); | ||||
|   dim3 block(1024); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "scaled_fp8_quant_kernel", [&] { | ||||
|         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|             out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(), | ||||
|             scale.data_ptr<float>(), num_elems); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d] | ||||
|                               torch::Tensor& input,  // [..., d] | ||||
|                               torch::Tensor& scale)  // [1] | ||||
| { | ||||
|   int64_t num_tokens = input.numel() / input.size(-1); | ||||
|   int64_t num_elems = input.numel(); | ||||
|   dim3 grid(num_tokens); | ||||
|   dim3 block(1024); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "scaled_fp8_quant_kernel", [&] { | ||||
|         vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>( | ||||
|             scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems); | ||||
|         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|             out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(), | ||||
|             scale.data_ptr<float>(), num_elems); | ||||
|       }); | ||||
| } | ||||
| @ -1,126 +0,0 @@ | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
| #include <torch/extension.h> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
|  | ||||
| #include <cmath> | ||||
|  | ||||
| #include "cuda_compat.h" | ||||
| #include "dispatch_utils.h" | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { | ||||
|     float old; | ||||
|     old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : | ||||
|          __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); | ||||
|  | ||||
|     return old; | ||||
| } | ||||
|  | ||||
| // Compute the absolute maximum m of the input tensor and store | ||||
| // m / float8_e4m3::max() in *scale. Each thread block performs a | ||||
| // reduction tree and the memory in scale is atomically updated. | ||||
| // So to get the right answer, *scale needs to be initialized to | ||||
| // a value <= 0.0 and we need to wait for all thread blocks to | ||||
| // finish before consuming *scale. | ||||
| template<typename scalar_t> | ||||
| __global__ void segmented_max_reduction( | ||||
|   float* __restrict__ scale, | ||||
|   const scalar_t* __restrict__ input, | ||||
|   int64_t num_elems) { | ||||
|   __shared__ float cache[1024]; | ||||
|   int i = blockDim.x * blockIdx.x + threadIdx.x; | ||||
|  | ||||
|   // First store maximum for all values processes by | ||||
|   // the current thread in cache[threadIdx.x] | ||||
|   scalar_t tmp = 0.0; | ||||
|   while (i < num_elems) { | ||||
|     float x = static_cast<float>(input[i]); | ||||
|     tmp = max(tmp, fabs(x)); | ||||
|     i += blockDim.x * gridDim.x; | ||||
|   } | ||||
|   cache[threadIdx.x] = tmp; | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // Now perform parallel reduction within the thread block | ||||
|   int ib = blockDim.x / 2; | ||||
|   while (ib != 0) { | ||||
|     if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { | ||||
|         cache[threadIdx.x] = cache[threadIdx.x + ib]; | ||||
|     } | ||||
|     __syncthreads(); | ||||
|     ib /= 2; | ||||
|   } | ||||
|   // Finally, since cache[0] contains the maximum for this thread block, | ||||
|   // atomically write the max to the target location | ||||
|   if (threadIdx.x == 0) { | ||||
|     atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| __global__ void scaled_fp8_quant_kernel( | ||||
|   c10::Float8_e4m3fn* __restrict__ out, | ||||
|   const scalar_t* __restrict__ input, | ||||
|   const float* __restrict__ scale, | ||||
|   int64_t num_elems) { | ||||
|   int i = blockDim.x * blockIdx.x + threadIdx.x; | ||||
|   while (i < num_elems) { | ||||
|     out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale); | ||||
|     i += blockDim.x * gridDim.x; | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace vllm | ||||
|  | ||||
| void static_scaled_fp8_quant( | ||||
|   torch::Tensor& out,      // [..., d] | ||||
|   torch::Tensor& input,    // [..., d] | ||||
|   torch::Tensor& scale)    // [1] | ||||
| { | ||||
|   int64_t num_tokens = input.numel() / input.size(-1); | ||||
|   int64_t num_elems = input.numel(); | ||||
|   dim3 grid(num_tokens); | ||||
|   dim3 block(1024); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|     input.scalar_type(), | ||||
|     "scaled_fp8_quant_kernel", | ||||
|     [&] { | ||||
|       vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|         out.data_ptr<c10::Float8_e4m3fn>(), | ||||
|         input.data_ptr<scalar_t>(), | ||||
|         scale.data_ptr<float>(), | ||||
|         num_elems); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void dynamic_scaled_fp8_quant( | ||||
|   torch::Tensor& out,      // [..., d] | ||||
|   torch::Tensor& input,    // [..., d] | ||||
|   torch::Tensor& scale)    // [1] | ||||
| { | ||||
|   int64_t num_tokens = input.numel() / input.size(-1); | ||||
|   int64_t num_elems = input.numel(); | ||||
|   dim3 grid(num_tokens); | ||||
|   dim3 block(1024); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|     input.scalar_type(), | ||||
|     "scaled_fp8_quant_kernel", | ||||
|     [&] { | ||||
|       vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>( | ||||
|         scale.data_ptr<float>(), | ||||
|         input.data_ptr<scalar_t>(), | ||||
|         num_elems); | ||||
|       vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||
|         out.data_ptr<c10::Float8_e4m3fn>(), | ||||
|         input.data_ptr<scalar_t>(), | ||||
|         scale.data_ptr<float>(), | ||||
|         num_elems); | ||||
|       }); | ||||
| } | ||||
|  | ||||
							
								
								
									
										570
									
								
								csrc/quantization/fp8/nvidia/quant_utils.cuh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										570
									
								
								csrc/quantization/fp8/nvidia/quant_utils.cuh
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,570 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "../../../attention/attention_dtypes.h" | ||||
| #include <assert.h> | ||||
| #include <float.h> | ||||
| #include <stdint.h> | ||||
| #include <type_traits> | ||||
|  | ||||
| namespace vllm { | ||||
| #ifndef USE_ROCM | ||||
|  | ||||
| namespace fp8 { | ||||
|   #ifdef ENABLE_FP8 | ||||
|  | ||||
|     #if 0  // Disable the following code to reduce the binary size. | ||||
| template <typename Tout, typename Tin> | ||||
| __inline__ __device__ Tout | ||||
| vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   return x; | ||||
| } | ||||
|  | ||||
| // fp8 -> half | ||||
| template <> | ||||
| __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>( | ||||
|     const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); | ||||
|   return res.x; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>( | ||||
|     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   union { | ||||
|     uint16_t u16[2]; | ||||
|     uint32_t u32; | ||||
|   } tmp; | ||||
|   __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); | ||||
|   tmp.u16[0] = res.x; | ||||
|   tmp.u16[1] = res.y; | ||||
|   return tmp.u32; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>( | ||||
|     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   union { | ||||
|     uint2 u32x2; | ||||
|     uint32_t u32[2]; | ||||
|   } tmp; | ||||
|   tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type); | ||||
|   tmp.u32[1] = | ||||
|       vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type); | ||||
|   return tmp.u32x2; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, uint2>( | ||||
|     const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   union { | ||||
|     uint4 u64x2; | ||||
|     uint2 u64[2]; | ||||
|   } tmp; | ||||
|   tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type); | ||||
|   tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type); | ||||
|   return tmp.u64x2; | ||||
| } | ||||
|  | ||||
| // fp8 -> __nv_bfloat16 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>( | ||||
|     const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   // Note there is no direct convert function from fp8 to bf16. | ||||
|   // fp8 -> half | ||||
|   __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); | ||||
|   // half -> float -> bf16 | ||||
|   float tmp = half_to_float(res.x); | ||||
|   return __float2bfloat16(tmp); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> __nv_bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>( | ||||
|     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __nv_bfloat162 res; | ||||
|   res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type); | ||||
|   res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> bf16_4_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>( | ||||
|     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   bf16_4_t res; | ||||
|   res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type); | ||||
|   res.y = | ||||
|       vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> bf16_8_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>( | ||||
|     const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   bf16_4_t tmp1, tmp2; | ||||
|   tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type); | ||||
|   tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type); | ||||
|   bf16_8_t res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8 -> float | ||||
| template <> | ||||
| __inline__ __device__ float | ||||
| vec_conversion<float, uint8_t>(const uint8_t &a, | ||||
|                                const __nv_fp8_interpretation_t fp8_type) { | ||||
|   // fp8 -> half | ||||
|   uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type); | ||||
|   // half -> float | ||||
|   return half_to_float(tmp); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> float2 | ||||
| template <> | ||||
| __inline__ __device__ float2 vec_conversion<float2, uint16_t>( | ||||
|     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   // fp8x2 -> half2 | ||||
|   uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type); | ||||
|   // half2 -> float2 | ||||
|   return half2_to_float2(tmp); | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>( | ||||
|     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   Float4_ res; | ||||
|   res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type); | ||||
|   res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> float8 | ||||
| template <> | ||||
| __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>( | ||||
|     const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   Float4_ tmp1, tmp2; | ||||
|   tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type); | ||||
|   tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type); | ||||
|   Float8_ res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // half -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>( | ||||
|     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __half_raw tmp; | ||||
|   tmp.x = a; | ||||
|   __nv_fp8_storage_t res = | ||||
|       __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type); | ||||
|   return (uint8_t)res; | ||||
| } | ||||
|  | ||||
| // bf16 -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>( | ||||
|     const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|       #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|   assert(false); | ||||
|       #else | ||||
|   __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( | ||||
|       __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); | ||||
|   return (uint8_t)res; | ||||
|       #endif | ||||
| } | ||||
|  | ||||
| // float -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, float>( | ||||
|     const float &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type); | ||||
|   return (uint8_t)res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 vec_conversion<float4, uint32_t>( | ||||
|     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type); | ||||
|   float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>( | ||||
|     const float2 &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   union { | ||||
|     half2 float16; | ||||
|     uint32_t uint32; | ||||
|   }; | ||||
|  | ||||
|   float16 = __float22half2_rn(a); | ||||
|   return uint32; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, Float4_>( | ||||
|     const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   uint2 b; | ||||
|   float2 val; | ||||
|   val.x = a.x.x; | ||||
|   val.y = a.x.y; | ||||
|   b.x = vec_conversion<uint32_t, float2>(val, fp8_type); | ||||
|  | ||||
|   val.x = a.y.x; | ||||
|   val.y = a.y.y; | ||||
|   b.y = vec_conversion<uint32_t, float2>(val, fp8_type); | ||||
|  | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __inline__ __device__ float4 vec_conversion<float4, Float4_>( | ||||
|     const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   float4 b; | ||||
|   b.x = a.x.x; | ||||
|   b.y = a.x.y; | ||||
|   b.z = a.y.x; | ||||
|   b.w = a.y.y; | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, Float8_>( | ||||
|     const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   uint4 b; | ||||
|   b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type); | ||||
|   b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type); | ||||
|   b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type); | ||||
|   b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>( | ||||
|     const float2 &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __nv_bfloat162 b; | ||||
|   from_float(b, a); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>( | ||||
|     const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   bf16_4_t b; | ||||
|   from_float(b, a); | ||||
|   return b; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>( | ||||
|     const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   bf16_8_t b; | ||||
|   from_float(b, a); | ||||
|   return b; | ||||
| } | ||||
|     #endif | ||||
|  | ||||
| /* Scaled and vectorized conversions, for data exchange between high and low | ||||
|    precision domains Convention of the scale in API, e.g: FP8_data = | ||||
|    Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 | ||||
|      Dequant(FP8) * scale =>  HP | ||||
|  */ | ||||
|  | ||||
| template <typename Tout, typename Tin> | ||||
| __inline__ __device__ Tout scaled_vec_conversion( | ||||
|     const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) { | ||||
|   return x; | ||||
| } | ||||
|  | ||||
| // fp8 -> half | ||||
| template <> | ||||
| __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>( | ||||
|     const uint8_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); | ||||
|   return float_to_half(half_to_float(tmp.x) * scale); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> half2 | ||||
| template <> | ||||
| __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>( | ||||
|     const uint16_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   union { | ||||
|     uint16_t u16[2]; | ||||
|     uint32_t u32; | ||||
|   } tmp; | ||||
|   __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); | ||||
|   tmp.u16[0] = float_to_half(half_to_float(res.x) * scale); | ||||
|   tmp.u16[1] = float_to_half(half_to_float(res.y) * scale); | ||||
|   return tmp.u32; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> half2x2 | ||||
| template <> | ||||
| __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>( | ||||
|     const uint32_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   union { | ||||
|     uint2 u32x2; | ||||
|     uint32_t u32[2]; | ||||
|   } tmp; | ||||
|   tmp.u32[0] = | ||||
|       scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type); | ||||
|   tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), | ||||
|                                                          scale, fp8_type); | ||||
|   return tmp.u32x2; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> half2x4 | ||||
| template <> | ||||
| __inline__ __device__ uint4 | ||||
| scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale, | ||||
|                                     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   union { | ||||
|     uint4 u64x2; | ||||
|     uint2 u64[2]; | ||||
|   } tmp; | ||||
|   tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type); | ||||
|   tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type); | ||||
|   return tmp.u64x2; | ||||
| } | ||||
|  | ||||
| // fp8 -> __nv_bfloat16 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat16 | ||||
| scaled_vec_conversion<__nv_bfloat16, uint8_t>( | ||||
|     const uint8_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   // Note there is no direct convert function from fp8 to bf16. | ||||
|   // fp8 -> half | ||||
|   __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); | ||||
|   // half -> float -> bf16 | ||||
|   float tmp = half_to_float(res.x); | ||||
|   return __float2bfloat16(tmp * scale); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> __nv_bfloat162 | ||||
| template <> | ||||
| __inline__ __device__ __nv_bfloat162 | ||||
| scaled_vec_conversion<__nv_bfloat162, uint16_t>( | ||||
|     const uint16_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __nv_bfloat162 res; | ||||
|   res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, | ||||
|                                                         fp8_type); | ||||
|   res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), | ||||
|                                                         scale, fp8_type); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> bf16_4_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>( | ||||
|     const uint32_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   bf16_4_t res; | ||||
|   res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, | ||||
|                                                           fp8_type); | ||||
|   res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), | ||||
|                                                           scale, fp8_type); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> bf16_8_t | ||||
| template <> | ||||
| __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>( | ||||
|     const uint2& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   bf16_4_t tmp1, tmp2; | ||||
|   tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type); | ||||
|   tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type); | ||||
|   bf16_8_t res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8 -> float | ||||
| template <> | ||||
| __inline__ __device__ float scaled_vec_conversion<float, uint8_t>( | ||||
|     const uint8_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   // fp8 -> half | ||||
|   __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); | ||||
|   uint16_t tmp = res.x; | ||||
|  | ||||
|   // half -> float | ||||
|   return half_to_float(tmp) * scale; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> float2 | ||||
| template <> | ||||
| __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>( | ||||
|     const uint16_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   // fp8x2 -> half2 | ||||
|   uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type); | ||||
|   // half2 -> float2 | ||||
|   return half2_to_float2(tmp); | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>( | ||||
|     const uint32_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   Float4_ res; | ||||
|   res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type); | ||||
|   res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, | ||||
|                                                   fp8_type); | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> float8 | ||||
| template <> | ||||
| __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>( | ||||
|     const uint2& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   Float4_ tmp1, tmp2; | ||||
|   tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type); | ||||
|   tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type); | ||||
|   Float8_ res; | ||||
|   res.x = tmp1.x; | ||||
|   res.y = tmp1.y; | ||||
|   res.z = tmp2.x; | ||||
|   res.w = tmp2.y; | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| // half -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>( | ||||
|     const uint16_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __nv_fp8_storage_t res = | ||||
|       __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); | ||||
|   return (uint8_t)res; | ||||
| } | ||||
|  | ||||
| // bf16 -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( | ||||
|     const __nv_bfloat16& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|     #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|   assert(false); | ||||
|     #else | ||||
|   __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, | ||||
|                                                  __NV_SATFINITE, fp8_type); | ||||
|   return (uint8_t)res; | ||||
|     #endif | ||||
| } | ||||
|  | ||||
| // float -> fp8 | ||||
| template <> | ||||
| __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>( | ||||
|     const float& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   __nv_fp8_storage_t res = | ||||
|       __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); | ||||
|   return (uint8_t)res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template <> | ||||
| __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>( | ||||
|     const uint32_t& a, const float scale, | ||||
|     const __nv_fp8_interpretation_t fp8_type) { | ||||
|   Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type); | ||||
|   float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||
|   return res; | ||||
| } | ||||
|   #endif  // ENABLE_FP8 | ||||
|  | ||||
| template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> | ||||
| __inline__ __device__ Tout convert(const Tin& x) { | ||||
|   #if 0  // Disable the following code to reduce the binary size. | ||||
|   if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { | ||||
|     return vec_conversion<Tout, Tin>(x, __NV_E4M3); | ||||
|   } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { | ||||
|     return vec_conversion<Tout, Tin>(x, __NV_E5M2); | ||||
|   } | ||||
|   #endif | ||||
|   assert(false); | ||||
| } | ||||
|  | ||||
| template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> | ||||
| __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { | ||||
|   #ifdef ENABLE_FP8 | ||||
|   if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { | ||||
|     return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3); | ||||
|   } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { | ||||
|     return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2); | ||||
|   } | ||||
|   #endif | ||||
|   assert(false); | ||||
| } | ||||
|  | ||||
|   // The following macro is used to dispatch the conversion function based on | ||||
|   // the data type of the key and value cache. The FN is a macro that calls a | ||||
|   // function with template<typename scalar_t, typename cache_t, | ||||
|   // Fp8KVCacheDataType kv_dt>. | ||||
|   #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \ | ||||
|     if (KV_DTYPE == "auto") {                                                  \ | ||||
|       if (SRC_DTYPE == at::ScalarType::Float) {                                \ | ||||
|         FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \ | ||||
|       } else if (SRC_DTYPE == at::ScalarType::Half) {                          \ | ||||
|         FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \ | ||||
|       } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \ | ||||
|         FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \ | ||||
|       } else {                                                                 \ | ||||
|         TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ | ||||
|       }                                                                        \ | ||||
|     } else {                                                                   \ | ||||
|       if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                       \ | ||||
|         if (SRC_DTYPE == at::ScalarType::Float) {                              \ | ||||
|           FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);              \ | ||||
|         } else if (SRC_DTYPE == at::ScalarType::Half) {                        \ | ||||
|           FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);           \ | ||||
|         } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \ | ||||
|           FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);      \ | ||||
|         } else {                                                               \ | ||||
|           TORCH_CHECK(false,                                                   \ | ||||
|                       "Unsupported input type of kv cache: ", SRC_DTYPE);      \ | ||||
|         }                                                                      \ | ||||
|       } else if (KV_DTYPE == "fp8_e5m2") {                                     \ | ||||
|         if (SRC_DTYPE == at::ScalarType::Float) {                              \ | ||||
|           FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);              \ | ||||
|         } else if (SRC_DTYPE == at::ScalarType::Half) {                        \ | ||||
|           FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);           \ | ||||
|         } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \ | ||||
|           FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);      \ | ||||
|         } else {                                                               \ | ||||
|           TORCH_CHECK(false,                                                   \ | ||||
|                       "Unsupported input type of kv cache: ", SRC_DTYPE);      \ | ||||
|         }                                                                      \ | ||||
|       } else {                                                                 \ | ||||
|         TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \ | ||||
|       }                                                                        \ | ||||
|     } | ||||
|  | ||||
| }  // namespace fp8 | ||||
| #endif  // not USE_ROCM | ||||
| }  // namespace vllm | ||||
| @ -1,277 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <assert.h> | ||||
| #include <stdint.h> | ||||
| #include <float.h> | ||||
| #include <type_traits> | ||||
| #include "../../attention/attention_dtypes.h" | ||||
| #include "../../attention/dtype_float32.cuh" | ||||
| #include "../../attention/dtype_float16.cuh" | ||||
| #include "../../attention/dtype_bfloat16.cuh" | ||||
|  | ||||
|  | ||||
| namespace vllm { | ||||
| #ifdef ENABLE_FP8_E5M2 | ||||
| namespace fp8_e5m2_unscaled { | ||||
|  | ||||
| template<typename Tout, typename Tin> | ||||
| __inline__ __device__ Tout vec_conversion(const Tin& x) | ||||
| { | ||||
|     return x; | ||||
| } | ||||
|  | ||||
| // fp8 -> half | ||||
| template<> | ||||
| __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a) | ||||
| { | ||||
|     __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); | ||||
|     return res.x; | ||||
| } | ||||
|  | ||||
| // fp8x2 -> half2 | ||||
| template<> | ||||
| __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a) | ||||
| { | ||||
|     union { | ||||
|         uint16_t u16[2]; | ||||
|         uint32_t u32; | ||||
|     } tmp; | ||||
|     __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); | ||||
|     tmp.u16[0] = res.x; | ||||
|     tmp.u16[1] = res.y; | ||||
|     return tmp.u32; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> half2x2 | ||||
| template<> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     union { | ||||
|         uint2    u32x2; | ||||
|         uint32_t u32[2]; | ||||
|     } tmp; | ||||
|     tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a); | ||||
|     tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U)); | ||||
|     return tmp.u32x2; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> half2x4 | ||||
| template<> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) | ||||
| { | ||||
|     union { | ||||
|         uint4 u64x2; | ||||
|         uint2 u64[2]; | ||||
|     } tmp; | ||||
|     tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x); | ||||
|     tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y); | ||||
|     return tmp.u64x2; | ||||
| } | ||||
|  | ||||
| // fp8 -> __nv_bfloat16 | ||||
| template<> | ||||
| __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) | ||||
| { | ||||
|     // Note there is no direct convert function from fp8 to bf16. | ||||
|     // fp8 -> half | ||||
|     __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); | ||||
|     // half -> float -> bf16 | ||||
|     float tmp = half_to_float(res.x); | ||||
|     return __float2bfloat16(tmp); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> __nv_bfloat162 | ||||
| template<> | ||||
| __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) | ||||
| { | ||||
|     __nv_bfloat162 res; | ||||
|     res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); | ||||
|     res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> bf16_4_t | ||||
| template<> | ||||
| __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     bf16_4_t res; | ||||
|     res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); | ||||
|     res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> bf16_8_t | ||||
| template<> | ||||
| __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) | ||||
| { | ||||
|     bf16_4_t tmp1, tmp2; | ||||
|     tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); | ||||
|     tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); | ||||
|     bf16_8_t res; | ||||
|     res.x = tmp1.x; | ||||
|     res.y = tmp1.y; | ||||
|     res.z = tmp2.x; | ||||
|     res.w = tmp2.y; | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8 -> float | ||||
| template<> | ||||
| __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) | ||||
| { | ||||
|     // fp8 -> half | ||||
|     uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a); | ||||
|     // half -> float | ||||
|     return half_to_float(tmp); | ||||
| } | ||||
|  | ||||
| // fp8x2 -> float2 | ||||
| template<> | ||||
| __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a) | ||||
| { | ||||
|     // fp8x2 -> half2 | ||||
|     uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a); | ||||
|     // half2 -> float2 | ||||
|     return half2_to_float2(tmp); | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template<> | ||||
| __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     Float4_ res; | ||||
|     res.x = vec_conversion<float2, uint16_t>((uint16_t)a); | ||||
|     res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| // fp8x8 -> float8 | ||||
| template<> | ||||
| __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) | ||||
| { | ||||
|     Float4_ tmp1, tmp2; | ||||
|     tmp1 = vec_conversion<Float4_, uint32_t>(a.x); | ||||
|     tmp2 = vec_conversion<Float4_, uint32_t>(a.y); | ||||
|     Float8_ res; | ||||
|     res.x = tmp1.x; | ||||
|     res.y = tmp1.y; | ||||
|     res.z = tmp2.x; | ||||
|     res.w = tmp2.y; | ||||
|     return res; | ||||
| } | ||||
|  | ||||
|  | ||||
| // half -> fp8 | ||||
| template<> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a) | ||||
| { | ||||
|     __half_raw tmp; | ||||
|     tmp.x = a; | ||||
|     __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); | ||||
|     return (uint8_t)res; | ||||
| } | ||||
|  | ||||
| // bf16 -> fp8 | ||||
| template<> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) | ||||
| { | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|     assert(false); | ||||
| #else | ||||
|     __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); | ||||
|     return (uint8_t)res; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // float -> fp8 | ||||
| template<> | ||||
| __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) | ||||
| { | ||||
|     __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); | ||||
|     return (uint8_t)res; | ||||
| } | ||||
|  | ||||
| // fp8x4 -> float4 | ||||
| template<> | ||||
| __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a) | ||||
| { | ||||
|     Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); | ||||
|     float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
|  | ||||
| template<> | ||||
| __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a) | ||||
| { | ||||
|     union { | ||||
|         half2    float16; | ||||
|         uint32_t uint32; | ||||
|     }; | ||||
|  | ||||
|     float16 = __float22half2_rn(a); | ||||
|     return uint32; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) | ||||
| { | ||||
|     uint2  b; | ||||
|     float2 val; | ||||
|     val.x = a.x.x; | ||||
|     val.y = a.x.y; | ||||
|     b.x   = vec_conversion<uint32_t, float2>(val); | ||||
|  | ||||
|     val.x = a.y.x; | ||||
|     val.y = a.y.y; | ||||
|     b.y   = vec_conversion<uint32_t, float2>(val); | ||||
|  | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) | ||||
| { | ||||
|     float4 b; | ||||
|     b.x = a.x.x; | ||||
|     b.y = a.x.y; | ||||
|     b.z = a.y.x; | ||||
|     b.w = a.y.y; | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) | ||||
| { | ||||
|     uint4 b; | ||||
|     b.x = vec_conversion<uint32_t, float2>(a.x); | ||||
|     b.y = vec_conversion<uint32_t, float2>(a.y); | ||||
|     b.z = vec_conversion<uint32_t, float2>(a.z); | ||||
|     b.w = vec_conversion<uint32_t, float2>(a.w); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { | ||||
|     __nv_bfloat162 b; | ||||
|     from_float(b, a); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) { | ||||
|     bf16_4_t b; | ||||
|     from_float(b, a); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) { | ||||
|     bf16_8_t b; | ||||
|     from_float(b, a); | ||||
|     return b; | ||||
| } | ||||
|  | ||||
| } // namespace fp8_e5m2_unscaled | ||||
| #endif // ENABLE_FP8_E5M2 | ||||
| } // namespace vllm | ||||
| @ -9,54 +9,54 @@ namespace vllm { | ||||
| namespace gptq { | ||||
| // atomicAdd for half types, to support CC < 7.x | ||||
|  | ||||
| __device__ __forceinline__ void atomicAdd_half(half* address, half val) | ||||
| { | ||||
|     unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); | ||||
|     unsigned int old = *address_as_ui; | ||||
|     unsigned int assumed; | ||||
| __device__ __forceinline__ void atomicAdd_half(half* address, half val) { | ||||
|   unsigned int* address_as_ui = | ||||
|       (unsigned int*)((char*)address - ((size_t)address & 2)); | ||||
|   unsigned int old = *address_as_ui; | ||||
|   unsigned int assumed; | ||||
|  | ||||
|     do | ||||
|     { | ||||
|         assumed = old; | ||||
|         __half_raw hsum; | ||||
|         hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); | ||||
|         half tmpres = __hadd(hsum, val); | ||||
|         hsum = __half_raw(tmpres); | ||||
|         old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; | ||||
|         old = atomicCAS(address_as_ui, assumed, old); | ||||
|     } | ||||
|     while (assumed != old); | ||||
|   do { | ||||
|     assumed = old; | ||||
|     __half_raw hsum; | ||||
|     hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); | ||||
|     half tmpres = __hadd(hsum, val); | ||||
|     hsum = __half_raw(tmpres); | ||||
|     old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) | ||||
|                               : (old & 0xffff0000) | hsum.x; | ||||
|     old = atomicCAS(address_as_ui, assumed, old); | ||||
|   } while (assumed != old); | ||||
| } | ||||
|  | ||||
| // atomicAdd for half2 types | ||||
|  | ||||
| __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) | ||||
| { | ||||
|     unsigned int* address_as_ui = (unsigned int*)address; | ||||
|     unsigned int old = *address_as_ui; | ||||
|     unsigned int assumed; | ||||
|     do | ||||
|     { | ||||
|         assumed = old; | ||||
|         half2 old_val = *((half2*)&old); | ||||
|         half2 new_val = __hadd2(old_val, val); | ||||
|         old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); | ||||
|     } | ||||
|     while (assumed != old); | ||||
| __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { | ||||
|   unsigned int* address_as_ui = (unsigned int*)address; | ||||
|   unsigned int old = *address_as_ui; | ||||
|   unsigned int assumed; | ||||
|   do { | ||||
|     assumed = old; | ||||
|     half2 old_val = *((half2*)&old); | ||||
|     half2 new_val = __hadd2(old_val, val); | ||||
|     old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); | ||||
|   } while (assumed != old); | ||||
| } | ||||
|  | ||||
| // | ||||
|  | ||||
| #if defined(__CUDA_ARCH__) || defined(USE_ROCM) | ||||
| #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) | ||||
|   #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) | ||||
|  | ||||
| __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } | ||||
| __device__ __forceinline__ void atomicAdd(half* address, half val) { | ||||
|   atomicAdd_half(address, val); | ||||
| } | ||||
|  | ||||
| #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) | ||||
| __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } | ||||
| #endif | ||||
|     #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) | ||||
| __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { | ||||
|   atomicAdd_half2(address, val); | ||||
| } | ||||
|     #endif | ||||
|  | ||||
| #endif | ||||
|   #endif | ||||
| #endif | ||||
|  | ||||
| }  // namespace gptq | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| /* | ||||
| Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama | ||||
| Adapted from https://github.com/turboderp/exllamav2 and | ||||
| https://github.com/turboderp/exllama | ||||
| */ | ||||
|  | ||||
| #ifndef _matrix_view_cuh | ||||
| @ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo | ||||
| namespace vllm { | ||||
| namespace gptq { | ||||
|  | ||||
| class MatrixView_half | ||||
| { | ||||
| public: | ||||
|     const half* data; | ||||
|     const int height; | ||||
|     const int width; | ||||
| class MatrixView_half { | ||||
|  public: | ||||
|   const half* data; | ||||
|   const int height; | ||||
|   const int width; | ||||
|  | ||||
|     __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) | ||||
|         : data(data), height(height), width(width) | ||||
|     { } | ||||
|   __device__ __forceinline__ MatrixView_half(const half* data, const int height, | ||||
|                                              const int width) | ||||
|       : data(data), height(height), width(width) {} | ||||
|  | ||||
|     __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } | ||||
|     __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } | ||||
|     __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } | ||||
|     __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } | ||||
|   __device__ __forceinline__ half item(int row, int column) const { | ||||
|     return data[row * width + column]; | ||||
|   } | ||||
|   __device__ __forceinline__ half2 item_half2(int row, int column) const { | ||||
|     return ((half2*)data)[(row * width + column) / 2]; | ||||
|   } | ||||
|   __device__ __forceinline__ half2 item_half2half2(int row, int column) const { | ||||
|     return __half2half2(data[row * width + column]); | ||||
|   } | ||||
|   __device__ __forceinline__ const half* item_ptr(int row, int column) const { | ||||
|     return &data[row * width + column]; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const | ||||
|     { | ||||
|         half2* ptr = (half2*) item_ptr(row, column); | ||||
|         half2 i01 = ptr[0]; | ||||
|         half2 i23 = ptr[1]; | ||||
|         items[0] = __low2half(i01); | ||||
|         items[1] = __high2half(i01); | ||||
|         items[2] = __low2half(i23); | ||||
|         items[3] = __high2half(i23); | ||||
|     } | ||||
|     __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const | ||||
|     { | ||||
|         half2* ptr = (half2*)item_ptr(row, column); | ||||
|         half2 i01 = ptr[0]; | ||||
|         half2 i23 = ptr[1]; | ||||
|         items[0] = __half2float(__low2half(i01)); | ||||
|         items[1] = __half2float(__high2half(i01)); | ||||
|         items[2] = __half2float(__low2half(i23)); | ||||
|         items[3] = __half2float(__high2half(i23)); | ||||
|     } | ||||
|   __device__ __forceinline__ void item4(half (&items)[4], int row, | ||||
|                                         int column) const { | ||||
|     half2* ptr = (half2*)item_ptr(row, column); | ||||
|     half2 i01 = ptr[0]; | ||||
|     half2 i23 = ptr[1]; | ||||
|     items[0] = __low2half(i01); | ||||
|     items[1] = __high2half(i01); | ||||
|     items[2] = __low2half(i23); | ||||
|     items[3] = __high2half(i23); | ||||
|   } | ||||
|   __device__ __forceinline__ void item4_f(float (&items)[4], int row, | ||||
|                                           int column) const { | ||||
|     half2* ptr = (half2*)item_ptr(row, column); | ||||
|     half2 i01 = ptr[0]; | ||||
|     half2 i23 = ptr[1]; | ||||
|     items[0] = __half2float(__low2half(i01)); | ||||
|     items[1] = __half2float(__high2half(i01)); | ||||
|     items[2] = __half2float(__low2half(i23)); | ||||
|     items[3] = __half2float(__high2half(i23)); | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const | ||||
|     { | ||||
|         half2* ptr = (half2*)item_ptr(row, column); | ||||
|         half2 i01 = ptr[0]; | ||||
|         half2 i23 = ptr[1]; | ||||
|         items[0] = __half2half2(__low2half(i01)); | ||||
|         items[1] = __half2half2(__high2half(i01)); | ||||
|         items[2] = __half2half2(__low2half(i23)); | ||||
|         items[3] = __half2half2(__high2half(i23)); | ||||
|     } | ||||
|   __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, | ||||
|                                            int column) const { | ||||
|     half2* ptr = (half2*)item_ptr(row, column); | ||||
|     half2 i01 = ptr[0]; | ||||
|     half2 i23 = ptr[1]; | ||||
|     items[0] = __half2half2(__low2half(i01)); | ||||
|     items[1] = __half2half2(__high2half(i01)); | ||||
|     items[2] = __half2half2(__low2half(i23)); | ||||
|     items[3] = __half2half2(__high2half(i23)); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| class MatrixView_half_rw | ||||
| { | ||||
| public: | ||||
|     half* data; | ||||
|     const int height; | ||||
|     const int width; | ||||
| class MatrixView_half_rw { | ||||
|  public: | ||||
|   half* data; | ||||
|   const int height; | ||||
|   const int width; | ||||
|  | ||||
|     __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) | ||||
|         : data(data), height(height), width(width) | ||||
|     { } | ||||
|   __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, | ||||
|                                                 const int width) | ||||
|       : data(data), height(height), width(width) {} | ||||
|  | ||||
|     __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } | ||||
|     __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } | ||||
|     __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } | ||||
|     __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } | ||||
|     __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } | ||||
|   __device__ __forceinline__ half item(int row, int column) const { | ||||
|     return data[row * width + column]; | ||||
|   } | ||||
|   __device__ __forceinline__ half2 item_half2(int row, int column) const { | ||||
|     return ((half2*)data)[(row * width + column) / 2]; | ||||
|   } | ||||
|   __device__ __forceinline__ half* item_ptr(int row, int column) { | ||||
|     return &data[row * width + column]; | ||||
|   } | ||||
|   __device__ __forceinline__ void set(int row, int column, half value) { | ||||
|     data[row * width + column] = value; | ||||
|   } | ||||
|   __device__ __forceinline__ void set_half2(int row, int column, half2 value) { | ||||
|     ((half2*)data)[(row * width + column) / 2] = value; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) | ||||
|     { | ||||
|         half2 v01 = __halves2half2(v0, v1); | ||||
|         half2 v23 = __halves2half2(v2, v3); | ||||
|         half2* ptr = (half2*) item_ptr(row, column); | ||||
|         ptr[0] = v01; | ||||
|         ptr[1] = v23; | ||||
|     } | ||||
|   __device__ __forceinline__ void set4(int row, int column, half v0, half v1, | ||||
|                                        half v2, half v3) { | ||||
|     half2 v01 = __halves2half2(v0, v1); | ||||
|     half2 v23 = __halves2half2(v2, v3); | ||||
|     half2* ptr = (half2*)item_ptr(row, column); | ||||
|     ptr[0] = v01; | ||||
|     ptr[1] = v23; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| class MatrixView_q4_row | ||||
| { | ||||
| public: | ||||
|     const uint32_t* data; | ||||
|     const int height; | ||||
|     const int width; | ||||
| class MatrixView_q4_row { | ||||
|  public: | ||||
|   const uint32_t* data; | ||||
|   const int height; | ||||
|   const int width; | ||||
|  | ||||
|     __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) | ||||
|         : data(data), height(height), width(width) | ||||
|     { } | ||||
|   __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, | ||||
|                                                const int height, | ||||
|                                                const int width) | ||||
|       : data(data), height(height), width(width) {} | ||||
|  | ||||
|     __device__ __forceinline__ int item(int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x07) * 4; | ||||
|         return (data[row * width / 8 + column / 8] >> shift) & 0x0f; | ||||
|     } | ||||
|   __device__ __forceinline__ int item(int row, int column) const { | ||||
|     int shift = (column & 0x07) * 4; | ||||
|     return (data[row * width / 8 + column / 8] >> shift) & 0x0f; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x07) * 4; | ||||
|         uint32_t d = data[row * width / 8 + column / 8] >> shift; | ||||
|         items[0] = d & 0x0f; | ||||
|         items[1] = (d >> 4) & 0x0f; | ||||
|     } | ||||
|   __device__ __forceinline__ void item2(int (&items)[2], int row, | ||||
|                                         int column) const { | ||||
|     int shift = (column & 0x07) * 4; | ||||
|     uint32_t d = data[row * width / 8 + column / 8] >> shift; | ||||
|     items[0] = d & 0x0f; | ||||
|     items[1] = (d >> 4) & 0x0f; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x07) * 4; | ||||
|         uint32_t d = data[row * width / 8 + column / 8] >> shift; | ||||
|         items[0] = d & 0x0f; | ||||
|         items[1] = (d >> 4) & 0x0f; | ||||
|         items[2] = (d >> 8) & 0x0f; | ||||
|         items[3] = (d >> 12) & 0x0f; | ||||
|     } | ||||
|   __device__ __forceinline__ void item4(int (&items)[4], int row, | ||||
|                                         int column) const { | ||||
|     int shift = (column & 0x07) * 4; | ||||
|     uint32_t d = data[row * width / 8 + column / 8] >> shift; | ||||
|     items[0] = d & 0x0f; | ||||
|     items[1] = (d >> 4) & 0x0f; | ||||
|     items[2] = (d >> 8) & 0x0f; | ||||
|     items[3] = (d >> 12) & 0x0f; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| class MatrixView_q4_column | ||||
| { | ||||
| public: | ||||
|     const uint32_t* data; | ||||
|     const int height; | ||||
|     const int width; | ||||
| class MatrixView_q4_column { | ||||
|  public: | ||||
|   const uint32_t* data; | ||||
|   const int height; | ||||
|   const int width; | ||||
|  | ||||
|     __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) | ||||
|         : data(data), height(height), width(width) | ||||
|     { } | ||||
|   __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, | ||||
|                                                   const int height, | ||||
|                                                   const int width) | ||||
|       : data(data), height(height), width(width) {} | ||||
|  | ||||
|     __device__ __forceinline__ int item(int row, int column) const | ||||
|     { | ||||
|         int shift = (row & 0x07) * 4; | ||||
|         return (data[row / 8 * width + column] >> shift) & 0x0f; | ||||
|     } | ||||
|   __device__ __forceinline__ int item(int row, int column) const { | ||||
|     int shift = (row & 0x07) * 4; | ||||
|     return (data[row / 8 * width + column] >> shift) & 0x0f; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } | ||||
|     __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } | ||||
|   __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { | ||||
|     return data[row / 8 * width + column]; | ||||
|   } | ||||
|   __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, | ||||
|                                                              int column) { | ||||
|     return &data[row / 8 * width + column]; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| class MatrixView_q2_row | ||||
| { | ||||
| public: | ||||
|     const uint32_t* data; | ||||
|     const int height; | ||||
|     const int width; | ||||
| class MatrixView_q2_row { | ||||
|  public: | ||||
|   const uint32_t* data; | ||||
|   const int height; | ||||
|   const int width; | ||||
|  | ||||
|     __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) | ||||
|         : data(data), height(height), width(width) | ||||
|     { } | ||||
|   __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, | ||||
|                                                const int height, | ||||
|                                                const int width) | ||||
|       : data(data), height(height), width(width) {} | ||||
|  | ||||
|     __device__ __forceinline__ int item(int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x0f) * 2; | ||||
|         return (data[row * width / 16 + column / 16] >> shift) & 0x03; | ||||
|     } | ||||
|   __device__ __forceinline__ int item(int row, int column) const { | ||||
|     int shift = (column & 0x0f) * 2; | ||||
|     return (data[row * width / 16 + column / 16] >> shift) & 0x03; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x0f) * 2; | ||||
|         uint32_t d = data[row * width / 16 + column / 16] >> shift; | ||||
|         items[0] = d & 0x03; | ||||
|         items[1] = (d >> 2) & 0x03; | ||||
|     } | ||||
|   __device__ __forceinline__ void item2(int (&items)[2], int row, | ||||
|                                         int column) const { | ||||
|     int shift = (column & 0x0f) * 2; | ||||
|     uint32_t d = data[row * width / 16 + column / 16] >> shift; | ||||
|     items[0] = d & 0x03; | ||||
|     items[1] = (d >> 2) & 0x03; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x0f) * 2; | ||||
|         uint32_t d = data[row * width / 16 + column / 16] >> shift; | ||||
|         items[0] = d & 0x03; | ||||
|         items[1] = (d >> 2) & 0x03; | ||||
|         items[2] = (d >> 4) & 0x03; | ||||
|         items[3] = (d >> 6) & 0x03; | ||||
|     } | ||||
|   __device__ __forceinline__ void item4(int (&items)[4], int row, | ||||
|                                         int column) const { | ||||
|     int shift = (column & 0x0f) * 2; | ||||
|     uint32_t d = data[row * width / 16 + column / 16] >> shift; | ||||
|     items[0] = d & 0x03; | ||||
|     items[1] = (d >> 2) & 0x03; | ||||
|     items[2] = (d >> 4) & 0x03; | ||||
|     items[3] = (d >> 6) & 0x03; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| class MatrixView_q3_row | ||||
| { | ||||
| public: | ||||
|     const uint32_t* data; | ||||
|     const int height; | ||||
|     const int width; | ||||
| class MatrixView_q3_row { | ||||
|  public: | ||||
|   const uint32_t* data; | ||||
|   const int height; | ||||
|   const int width; | ||||
|  | ||||
|     __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) | ||||
|         : data(data), height(height), width(width) | ||||
|     { } | ||||
|   __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, | ||||
|                                                const int height, | ||||
|                                                const int width) | ||||
|       : data(data), height(height), width(width) {} | ||||
|  | ||||
|     __device__ __forceinline__ int item(int row, int column) const | ||||
|     { | ||||
|         int z_w = column * 3 / 32; | ||||
|         int z_mod =  column & 0x1f; | ||||
|   __device__ __forceinline__ int item(int row, int column) const { | ||||
|     int z_w = column * 3 / 32; | ||||
|     int z_mod = column & 0x1f; | ||||
|  | ||||
|         if (z_mod == 10) { | ||||
|             return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); | ||||
|         } else if (z_mod == 21) { | ||||
|             return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); | ||||
|         } else if (z_mod < 10) { | ||||
|             return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; | ||||
|         } else if (z_mod < 21) { | ||||
|             return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3  - 32)) & 0x07; | ||||
|         } else { | ||||
|             return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3  - 64)) & 0x07; | ||||
|         } | ||||
|     if (z_mod == 10) { | ||||
|       return (data[row * width * 3 / 32 + z_w] >> 30) | | ||||
|              ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); | ||||
|     } else if (z_mod == 21) { | ||||
|       return (data[row * width * 3 / 32 + z_w] >> 31) | | ||||
|              ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); | ||||
|     } else if (z_mod < 10) { | ||||
|       return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; | ||||
|     } else if (z_mod < 21) { | ||||
|       return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; | ||||
|     } else { | ||||
|       return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x1f); | ||||
|         uint32_t d; | ||||
|         if (shift <= 4) { | ||||
|             d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); | ||||
|         } else if (shift == 8) { | ||||
|             d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); | ||||
|         } else if (shift <= 16) { | ||||
|             d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); | ||||
|         } else if (shift == 20) { | ||||
|             d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); | ||||
|         } else { | ||||
|             d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); | ||||
|         } | ||||
|         items[0] = d & 0x07; | ||||
|         items[1] = (d >> 3) & 0x07; | ||||
|         items[2] = (d >> 6) & 0x07; | ||||
|         items[3] = (d >> 9) & 0x07; | ||||
|   __device__ __forceinline__ void item4(int (&items)[4], int row, | ||||
|                                         int column) const { | ||||
|     int shift = (column & 0x1f); | ||||
|     uint32_t d; | ||||
|     if (shift <= 4) { | ||||
|       d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); | ||||
|     } else if (shift == 8) { | ||||
|       d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | | ||||
|           ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); | ||||
|     } else if (shift <= 16) { | ||||
|       d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); | ||||
|     } else if (shift == 20) { | ||||
|       d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | | ||||
|           ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); | ||||
|     } else { | ||||
|       d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); | ||||
|     } | ||||
|     items[0] = d & 0x07; | ||||
|     items[1] = (d >> 3) & 0x07; | ||||
|     items[2] = (d >> 6) & 0x07; | ||||
|     items[3] = (d >> 9) & 0x07; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| class MatrixView_q8_row | ||||
| { | ||||
| public: | ||||
|     const uint32_t* data; | ||||
|     const int height; | ||||
|     const int width; | ||||
| class MatrixView_q8_row { | ||||
|  public: | ||||
|   const uint32_t* data; | ||||
|   const int height; | ||||
|   const int width; | ||||
|  | ||||
|     __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) | ||||
|         : data(data), height(height), width(width) | ||||
|     { } | ||||
|   __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, | ||||
|                                                const int height, | ||||
|                                                const int width) | ||||
|       : data(data), height(height), width(width) {} | ||||
|  | ||||
|     __device__ __forceinline__ int item(int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x03) * 8; | ||||
|         return (data[row * width / 4 + column / 4] >> shift) & 0xff; | ||||
|     } | ||||
|   __device__ __forceinline__ int item(int row, int column) const { | ||||
|     int shift = (column & 0x03) * 8; | ||||
|     return (data[row * width / 4 + column / 4] >> shift) & 0xff; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x03) * 8; | ||||
|         uint32_t d = data[row * width / 4 + column / 4] >> shift; | ||||
|         items[0] = d & 0xff; | ||||
|         items[1] = (d >> 8) & 0xff; | ||||
|     } | ||||
|   __device__ __forceinline__ void item2(int (&items)[2], int row, | ||||
|                                         int column) const { | ||||
|     int shift = (column & 0x03) * 8; | ||||
|     uint32_t d = data[row * width / 4 + column / 4] >> shift; | ||||
|     items[0] = d & 0xff; | ||||
|     items[1] = (d >> 8) & 0xff; | ||||
|   } | ||||
|  | ||||
|     __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const | ||||
|     { | ||||
|         int shift = (column & 0x03) * 2; | ||||
|         uint32_t d = data[row * width / 4 + column / 4] >> shift; | ||||
|         items[0] = d & 0xff; | ||||
|         items[1] = (d >> 8) & 0xff; | ||||
|         items[2] = (d >> 16) & 0xff; | ||||
|         items[3] = (d >> 24) & 0xff; | ||||
|     } | ||||
|   __device__ __forceinline__ void item4(int (&items)[4], int row, | ||||
|                                         int column) const { | ||||
|     int shift = (column & 0x03) * 2; | ||||
|     uint32_t d = data[row * width / 4 + column / 4] >> shift; | ||||
|     items[0] = d & 0xff; | ||||
|     items[1] = (d >> 8) & 0xff; | ||||
|     items[2] = (d >> 16) & 0xff; | ||||
|     items[3] = (d >> 24) & 0xff; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| }  // namespace gptq | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -14,71 +14,60 @@ namespace gptq { | ||||
| // | ||||
| // ffddbb99 77553311  eeccaa88 66442200 | ||||
|  | ||||
| __forceinline__ __device__ void shuffle_2bit_16 | ||||
| ( | ||||
|     uint32_t* q, | ||||
|     int stride | ||||
| ) | ||||
| { | ||||
|     uint32_t qa = q[0]; | ||||
|     uint32_t qb = 0; | ||||
| __forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { | ||||
|   uint32_t qa = q[0]; | ||||
|   uint32_t qb = 0; | ||||
|  | ||||
|     #pragma unroll | ||||
|     for (int i = 0; i < 8; i++) | ||||
|     { | ||||
|         uint32_t qa0 = qa & 0x03; | ||||
|         uint32_t qa1 = (qa & 0x0c) >> 2; | ||||
|         qa >>= 4; | ||||
|         qb |= (qa1 << (i * 2 + 16)); | ||||
|         qb |= (qa0 << (i * 2)); | ||||
|     } | ||||
|     q[0] = qb; | ||||
| #pragma unroll | ||||
|   for (int i = 0; i < 8; i++) { | ||||
|     uint32_t qa0 = qa & 0x03; | ||||
|     uint32_t qa1 = (qa & 0x0c) >> 2; | ||||
|     qa >>= 4; | ||||
|     qb |= (qa1 << (i * 2 + 16)); | ||||
|     qb |= (qa0 << (i * 2)); | ||||
|   } | ||||
|   q[0] = qb; | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ void dequant_2bit_16 | ||||
| ( | ||||
|     const uint32_t q_0, | ||||
|     half2 (&dq)[8], | ||||
|     int stride, | ||||
|     const uint32_t zero | ||||
| ) | ||||
| { | ||||
|     const uint32_t c0 = 0x64006400; | ||||
|     const half y4_  = __float2half_rn(1.0f /  4.0f); | ||||
|     const half y16_ = __float2half_rn(1.0f / 16.0f); | ||||
|     const half y64_ = __float2half_rn(1.0f / 64.0f); | ||||
|     const half2 y4  = __halves2half2(y4_,  y4_); | ||||
|     const half2 y16 = __halves2half2(y16_, y16_); | ||||
|     const half2 y64 = __halves2half2(y64_, y64_); | ||||
| __forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, | ||||
|                                                 half2 (&dq)[8], int stride, | ||||
|                                                 const uint32_t zero) { | ||||
|   const uint32_t c0 = 0x64006400; | ||||
|   const half y4_ = __float2half_rn(1.0f / 4.0f); | ||||
|   const half y16_ = __float2half_rn(1.0f / 16.0f); | ||||
|   const half y64_ = __float2half_rn(1.0f / 64.0f); | ||||
|   const half2 y4 = __halves2half2(y4_, y4_); | ||||
|   const half2 y16 = __halves2half2(y16_, y16_); | ||||
|   const half2 y64 = __halves2half2(y64_, y64_); | ||||
|  | ||||
|     const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); | ||||
|     const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); | ||||
|     const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
|     const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); | ||||
|     const half2 z1 = __half2half2(z1_.as_half); | ||||
|     const half2 z4 = __half2half2(z4_); | ||||
|     const half2 z16 = __half2half2(z16_); | ||||
|     const half2 z64 = __half2half2(z64_); | ||||
|   const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero); | ||||
|   const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); | ||||
|   const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
|   const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); | ||||
|   const half2 z1 = __half2half2(z1_.as_half); | ||||
|   const half2 z4 = __half2half2(z4_); | ||||
|   const half2 z16 = __half2half2(z16_); | ||||
|   const half2 z64 = __half2half2(z64_); | ||||
|  | ||||
|     uint32_t qa = q_0; | ||||
|     half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024 | ||||
|     half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024 | ||||
|     half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 | ||||
|     half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 | ||||
|     qa >>= 8; | ||||
|     half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8])      + 1024 | ||||
|     half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024 | ||||
|     half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 | ||||
|     half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 | ||||
|   uint32_t qa = q_0; | ||||
|   half2_uint32 q0((qa & 0x00030003) | c0);  // half2(q[ 0], q[ 1])      + 1024 | ||||
|   half2_uint32 q1((qa & 0x000c000c) | c0);  // half2(q[ 2], q[ 3]) *  4 + 1024 | ||||
|   half2_uint32 q2((qa & 0x00300030) | c0);  // half2(q[ 4], q[ 5]) * 16 + 1024 | ||||
|   half2_uint32 q3((qa & 0x00c000c0) | c0);  // half2(q[ 6], q[ 7]) * 64 + 1024 | ||||
|   qa >>= 8; | ||||
|   half2_uint32 q4((qa & 0x00030003) | c0);  // half2(q[ 8], q[ 8])      + 1024 | ||||
|   half2_uint32 q5((qa & 0x000c000c) | c0);  // half2(q[10], q[11]) *  4 + 1024 | ||||
|   half2_uint32 q6((qa & 0x00300030) | c0);  // half2(q[12], q[13]) * 16 + 1024 | ||||
|   half2_uint32 q7((qa & 0x00c000c0) | c0);  // half2(q[14], q[15]) * 64 + 1024 | ||||
|  | ||||
|     dq[0] = __hadd2(q0.as_half2, z1); | ||||
|     dq[1] = __hfma2(q1.as_half2, y4,  z4); | ||||
|     dq[2] = __hfma2(q2.as_half2, y16, z16); | ||||
|     dq[3] = __hfma2(q3.as_half2, y64, z64); | ||||
|     dq[4] = __hadd2(q4.as_half2, z1); | ||||
|     dq[5] = __hfma2(q5.as_half2, y4,  z4); | ||||
|     dq[6] = __hfma2(q6.as_half2, y16, z16); | ||||
|     dq[7] = __hfma2(q7.as_half2, y64, z64); | ||||
|   dq[0] = __hadd2(q0.as_half2, z1); | ||||
|   dq[1] = __hfma2(q1.as_half2, y4, z4); | ||||
|   dq[2] = __hfma2(q2.as_half2, y16, z16); | ||||
|   dq[3] = __hfma2(q3.as_half2, y64, z64); | ||||
|   dq[4] = __hadd2(q4.as_half2, z1); | ||||
|   dq[5] = __hfma2(q5.as_half2, y4, z4); | ||||
|   dq[6] = __hfma2(q6.as_half2, y16, z16); | ||||
|   dq[7] = __hfma2(q7.as_half2, y64, z64); | ||||
| } | ||||
|  | ||||
| }  // namespace gptq | ||||
|  | ||||
| @ -11,128 +11,136 @@ namespace gptq { | ||||
| // vjjjhhhf ffdddbbb  uiiiggge eecccaaa | ||||
| // vtttrrrp ppnnnlll  usssqqqo oommmkkk | ||||
|  | ||||
| __forceinline__ __device__ void shuffle_3bit_32 | ||||
| ( | ||||
|     uint32_t* q, | ||||
|     int stride | ||||
| ) | ||||
| { | ||||
|     uint32_t qa = q[0 * stride]; | ||||
|     uint32_t qb = q[1 * stride]; | ||||
|     uint32_t qc = q[2 * stride]; | ||||
| __forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { | ||||
|   uint32_t qa = q[0 * stride]; | ||||
|   uint32_t qb = q[1 * stride]; | ||||
|   uint32_t qc = q[2 * stride]; | ||||
|  | ||||
|     // qa: aa999888 77766655  54443332 22111000 | ||||
|     // qb: lkkkjjji iihhhggg  fffeeedd dcccbbba | ||||
|     // qc: vvvuuutt tsssrrrq  qqpppooo nnnmmmll | ||||
|   // qa: aa999888 77766655  54443332 22111000 | ||||
|   // qb: lkkkjjji iihhhggg  fffeeedd dcccbbba | ||||
|   // qc: vvvuuutt tsssrrrq  qqpppooo nnnmmmll | ||||
|  | ||||
|     uint32_t qd = qc >> 26; | ||||
|     qc <<= 4; | ||||
|     qc |= qb >> 28; | ||||
|     qb <<= 2; | ||||
|     qb |= qa >> 30; | ||||
|   uint32_t qd = qc >> 26; | ||||
|   qc <<= 4; | ||||
|   qc |= qb >> 28; | ||||
|   qb <<= 2; | ||||
|   qb |= qa >> 30; | ||||
|  | ||||
|     // qa: ..999888 77766655  54443332 22111000 | ||||
|     // qb: ..jjjiii hhhgggff  feeedddc ccbbbaaa | ||||
|     // qc: ..tttsss rrrqqqpp  pooonnnm mmlllkkk | ||||
|     // qd:                               vvvuuu | ||||
|   // qa: ..999888 77766655  54443332 22111000 | ||||
|   // qb: ..jjjiii hhhgggff  feeedddc ccbbbaaa | ||||
|   // qc: ..tttsss rrrqqqpp  pooonnnm mmlllkkk | ||||
|   // qd:                               vvvuuu | ||||
|  | ||||
|     uint32_t za = 0; | ||||
|     uint32_t zb = 0; | ||||
|     uint32_t zc = 0; | ||||
|   uint32_t za = 0; | ||||
|   uint32_t zb = 0; | ||||
|   uint32_t zc = 0; | ||||
|  | ||||
|     for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } | ||||
|     for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } | ||||
|     for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } | ||||
|   for (int i = 0; i < 5; i++) { | ||||
|     uint32_t t0 = qa & 0x07; | ||||
|     uint32_t t1 = (qa & 0x38) >> 3; | ||||
|     qa >>= 6; | ||||
|     za |= (t0 << (i * 3)); | ||||
|     za |= (t1 << (i * 3 + 16)); | ||||
|   } | ||||
|   for (int i = 0; i < 5; i++) { | ||||
|     uint32_t t0 = qb & 0x07; | ||||
|     uint32_t t1 = (qb & 0x38) >> 3; | ||||
|     qb >>= 6; | ||||
|     zb |= (t0 << (i * 3)); | ||||
|     zb |= (t1 << (i * 3 + 16)); | ||||
|   } | ||||
|   for (int i = 0; i < 5; i++) { | ||||
|     uint32_t t0 = qc & 0x07; | ||||
|     uint32_t t1 = (qc & 0x38) >> 3; | ||||
|     qc >>= 6; | ||||
|     zc |= (t0 << (i * 3)); | ||||
|     zc |= (t1 << (i * 3 + 16)); | ||||
|   } | ||||
|  | ||||
|     // za:  9997775 55333111   8886664 44222000 | ||||
|     // zb:  jjjhhhf ffdddbbb   iiiggge eecccaaa | ||||
|     // zc:  tttrrrp ppnnnlll   sssqqqo oommmkkk | ||||
|     // qd:                               vvvuuu | ||||
|   // za:  9997775 55333111   8886664 44222000 | ||||
|   // zb:  jjjhhhf ffdddbbb   iiiggge eecccaaa | ||||
|   // zc:  tttrrrp ppnnnlll   sssqqqo oommmkkk | ||||
|   // qd:                               vvvuuu | ||||
|  | ||||
|     za |= ((qd & 0x01) >> 0) << 15; | ||||
|     zb |= ((qd & 0x02) >> 1) << 15; | ||||
|     zc |= ((qd & 0x04) >> 2) << 15; | ||||
|     za |= ((qd & 0x08) >> 3) << 31; | ||||
|     zb |= ((qd & 0x10) >> 4) << 31; | ||||
|     zc |= ((qd & 0x20) >> 5) << 31; | ||||
|   za |= ((qd & 0x01) >> 0) << 15; | ||||
|   zb |= ((qd & 0x02) >> 1) << 15; | ||||
|   zc |= ((qd & 0x04) >> 2) << 15; | ||||
|   za |= ((qd & 0x08) >> 3) << 31; | ||||
|   zb |= ((qd & 0x10) >> 4) << 31; | ||||
|   zc |= ((qd & 0x20) >> 5) << 31; | ||||
|  | ||||
|     // za: v9997775 55333111  u8886664 44222000  (u, v lsb) | ||||
|     // zb: vjjjhhhf ffdddbbb  uiiiggge eecccaaa | ||||
|     // zc: vtttrrrp ppnnnlll  usssqqqo oommmkkk | ||||
|   // za: v9997775 55333111  u8886664 44222000  (u, v lsb) | ||||
|   // zb: vjjjhhhf ffdddbbb  uiiiggge eecccaaa | ||||
|   // zc: vtttrrrp ppnnnlll  usssqqqo oommmkkk | ||||
|  | ||||
|     q[0 * stride] = za; | ||||
|     q[1 * stride] = zb; | ||||
|     q[2 * stride] = zc; | ||||
|   q[0 * stride] = za; | ||||
|   q[1 * stride] = zb; | ||||
|   q[2 * stride] = zc; | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ void dequant_3bit_32 | ||||
| ( | ||||
|     const uint32_t q_0, | ||||
|     const uint32_t q_1, | ||||
|     const uint32_t q_2, | ||||
|     half2 (&dq)[16], | ||||
|     int stride, | ||||
|     const uint32_t zero | ||||
| ) | ||||
| { | ||||
|     const uint32_t c0 = 0x64006400; | ||||
|     const half y8_  = __float2half_rn(1.0f /  8.0f); | ||||
|     const half y64_ = __float2half_rn(1.0f / 64.0f); | ||||
|     const half2 y8  = __halves2half2(y8_,  y8_); | ||||
|     const half2 y64 = __halves2half2(y64_, y64_); | ||||
|     const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); | ||||
|     const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); | ||||
|     const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); | ||||
|     const half2 z1  = __halves2half2(z1_.as_half,  z1_.as_half); | ||||
|     const half2 z8  = __halves2half2(z8_,  z8_); | ||||
|     const half2 z64 = __halves2half2(z64_, z64_); | ||||
| __forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, | ||||
|                                                 const uint32_t q_1, | ||||
|                                                 const uint32_t q_2, | ||||
|                                                 half2 (&dq)[16], int stride, | ||||
|                                                 const uint32_t zero) { | ||||
|   const uint32_t c0 = 0x64006400; | ||||
|   const half y8_ = __float2half_rn(1.0f / 8.0f); | ||||
|   const half y64_ = __float2half_rn(1.0f / 64.0f); | ||||
|   const half2 y8 = __halves2half2(y8_, y8_); | ||||
|   const half2 y64 = __halves2half2(y64_, y64_); | ||||
|   const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero); | ||||
|   const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); | ||||
|   const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); | ||||
|   const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); | ||||
|   const half2 z8 = __halves2half2(z8_, z8_); | ||||
|   const half2 z64 = __halves2half2(z64_, z64_); | ||||
|  | ||||
|     uint32_t qa = q_0; | ||||
|     uint32_t qb = q_1; | ||||
|     uint32_t qc = q_2; | ||||
|   uint32_t qa = q_0; | ||||
|   uint32_t qb = q_1; | ||||
|   uint32_t qc = q_2; | ||||
|  | ||||
|     half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024 | ||||
|     half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024 | ||||
|     qa >>= 6; | ||||
|     half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024 | ||||
|     half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024 | ||||
|     half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 | ||||
|     qa >>= 9; | ||||
|     qa &= 0x00010001; | ||||
|     half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024 | ||||
|     half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024 | ||||
|     qb >>= 6; | ||||
|     half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024 | ||||
|     half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024 | ||||
|     half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 | ||||
|     qb >>= 8; | ||||
|     qb &= 0x00020002; | ||||
|     half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024 | ||||
|     half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024 | ||||
|     qc >>= 6; | ||||
|     half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024 | ||||
|     half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024 | ||||
|     half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 | ||||
|     qc >>= 7; | ||||
|     qc &= 0x00040004; | ||||
|     half2_uint32 q15((qa | qb | qc) | c0); | ||||
|   half2_uint32 q0((qa & 0x00070007) | c0);  // half2(q[ 0], q[ 1])      + 1024 | ||||
|   half2_uint32 q1((qa & 0x00380038) | c0);  // half2(q[ 2], q[ 3]) *  8 + 1024 | ||||
|   qa >>= 6; | ||||
|   half2_uint32 q2((qa & 0x00070007) | c0);  // half2(q[ 4], q[ 5])      + 1024 | ||||
|   half2_uint32 q3((qa & 0x00380038) | c0);  // half2(q[ 6], q[ 7]) *  8 + 1024 | ||||
|   half2_uint32 q4((qa & 0x01c001c0) | c0);  // half2(q[ 8], q[ 9]) * 64 + 1024 | ||||
|   qa >>= 9; | ||||
|   qa &= 0x00010001; | ||||
|   half2_uint32 q5((qb & 0x00070007) | c0);  // half2(q[10], q[11])      + 1024 | ||||
|   half2_uint32 q6((qb & 0x00380038) | c0);  // half2(q[12], q[13]) *  8 + 1024 | ||||
|   qb >>= 6; | ||||
|   half2_uint32 q7((qb & 0x00070007) | c0);  // half2(q[14], q[15])      + 1024 | ||||
|   half2_uint32 q8((qb & 0x00380038) | c0);  // half2(q[16], q[17]) *  8 + 1024 | ||||
|   half2_uint32 q9((qb & 0x01c001c0) | c0);  // half2(q[18], q[19]) * 64 + 1024 | ||||
|   qb >>= 8; | ||||
|   qb &= 0x00020002; | ||||
|   half2_uint32 q10((qc & 0x00070007) | c0);  // half2(q[20], q[21])      + 1024 | ||||
|   half2_uint32 q11((qc & 0x00380038) | c0);  // half2(q[22], q[23]) *  8 + 1024 | ||||
|   qc >>= 6; | ||||
|   half2_uint32 q12((qc & 0x00070007) | c0);  // half2(q[24], q[25])      + 1024 | ||||
|   half2_uint32 q13((qc & 0x00380038) | c0);  // half2(q[26], q[27]) *  8 + 1024 | ||||
|   half2_uint32 q14((qc & 0x01c001c0) | c0);  // half2(q[28], q[29]) * 64 + 1024 | ||||
|   qc >>= 7; | ||||
|   qc &= 0x00040004; | ||||
|   half2_uint32 q15((qa | qb | qc) | c0); | ||||
|  | ||||
|     dq[ 0] = __hadd2( q0.as_half2, z1); | ||||
|     dq[ 1] = __hfma2( q1.as_half2, y8,  z8); | ||||
|     dq[ 2] = __hadd2( q2.as_half2, z1); | ||||
|     dq[ 3] = __hfma2( q3.as_half2, y8,  z8); | ||||
|     dq[ 4] = __hfma2( q4.as_half2, y64, z64); | ||||
|     dq[ 5] = __hadd2( q5.as_half2, z1); | ||||
|     dq[ 6] = __hfma2( q6.as_half2, y8,  z8); | ||||
|     dq[ 7] = __hadd2( q7.as_half2, z1); | ||||
|     dq[ 8] = __hfma2( q8.as_half2, y8,  z8); | ||||
|     dq[ 9] = __hfma2( q9.as_half2, y64, z64); | ||||
|     dq[10] = __hadd2(q10.as_half2, z1); | ||||
|     dq[11] = __hfma2(q11.as_half2, y8,  z8); | ||||
|     dq[12] = __hadd2(q12.as_half2, z1); | ||||
|     dq[13] = __hfma2(q13.as_half2, y8,  z8); | ||||
|     dq[14] = __hfma2(q14.as_half2, y64, z64); | ||||
|     dq[15] = __hadd2(q15.as_half2, z1); | ||||
|   dq[0] = __hadd2(q0.as_half2, z1); | ||||
|   dq[1] = __hfma2(q1.as_half2, y8, z8); | ||||
|   dq[2] = __hadd2(q2.as_half2, z1); | ||||
|   dq[3] = __hfma2(q3.as_half2, y8, z8); | ||||
|   dq[4] = __hfma2(q4.as_half2, y64, z64); | ||||
|   dq[5] = __hadd2(q5.as_half2, z1); | ||||
|   dq[6] = __hfma2(q6.as_half2, y8, z8); | ||||
|   dq[7] = __hadd2(q7.as_half2, z1); | ||||
|   dq[8] = __hfma2(q8.as_half2, y8, z8); | ||||
|   dq[9] = __hfma2(q9.as_half2, y64, z64); | ||||
|   dq[10] = __hadd2(q10.as_half2, z1); | ||||
|   dq[11] = __hfma2(q11.as_half2, y8, z8); | ||||
|   dq[12] = __hadd2(q12.as_half2, z1); | ||||
|   dq[13] = __hfma2(q13.as_half2, y8, z8); | ||||
|   dq[14] = __hfma2(q14.as_half2, y64, z64); | ||||
|   dq[15] = __hadd2(q15.as_half2, z1); | ||||
| } | ||||
|  | ||||
| }  // namespace gptq | ||||
|  | ||||
| @ -13,133 +13,112 @@ namespace gptq { | ||||
| // | ||||
| // 77775555 33331111  66664444 22220000 | ||||
|  | ||||
| __forceinline__ __device__ void shuffle_4bit_8 | ||||
| ( | ||||
|     uint32_t* q, | ||||
|     int stride | ||||
| ) | ||||
| { | ||||
|     uint32_t qa = q[0]; | ||||
|     uint32_t qb = 0; | ||||
| __forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { | ||||
|   uint32_t qa = q[0]; | ||||
|   uint32_t qb = 0; | ||||
|  | ||||
|     #pragma unroll | ||||
|     for (int i = 0; i < 4; i++) | ||||
|     { | ||||
|         uint32_t qa0 = qa & 0x0f; | ||||
|         uint32_t qa1 = (qa & 0xf0) >> 4; | ||||
|         qa >>= 8; | ||||
|         qb |= (qa1 << (i * 4 + 16)); | ||||
|         qb |= (qa0 << (i * 4)); | ||||
|     } | ||||
|     q[0] = qb; | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ void dequant_4bit_8 | ||||
| ( | ||||
|     const uint32_t q_0, | ||||
|     half2 (&dq)[4], | ||||
|     int stride, | ||||
|     const uint32_t zero | ||||
| ) | ||||
| { | ||||
|     const uint32_t c0 = 0x64006400; | ||||
|     const half y16_ = __float2half_rn(1.0f / 16.0f); | ||||
|     const half2 y16 = __halves2half2(y16_, y16_); | ||||
|     const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); | ||||
|     const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
|     const half2 z1 = __half2half2(z1_.as_half); | ||||
|     const half2 z16 = __half2half2(z16_); | ||||
|  | ||||
|     uint32_t qa = q_0; | ||||
|     half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024 | ||||
|     half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 | ||||
| #pragma unroll | ||||
|   for (int i = 0; i < 4; i++) { | ||||
|     uint32_t qa0 = qa & 0x0f; | ||||
|     uint32_t qa1 = (qa & 0xf0) >> 4; | ||||
|     qa >>= 8; | ||||
|     half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024 | ||||
|     half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 | ||||
|  | ||||
|     dq[0] = __hadd2(q0.as_half2, z1); | ||||
|     dq[1] = __hfma2(q1.as_half2, y16, z16); | ||||
|     dq[2] = __hadd2(q2.as_half2, z1); | ||||
|     dq[3] = __hfma2(q3.as_half2, y16, z16); | ||||
|     qb |= (qa1 << (i * 4 + 16)); | ||||
|     qb |= (qa0 << (i * 4)); | ||||
|   } | ||||
|   q[0] = qb; | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale | ||||
| ( | ||||
|     const uint32_t zero, | ||||
|     const half scale, | ||||
|     half2 (&z1z16)[2], | ||||
|     half2 (&y1y16)[2] | ||||
| ) | ||||
| { | ||||
|     half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); | ||||
|     half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
| __forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, | ||||
|                                                half2 (&dq)[4], int stride, | ||||
|                                                const uint32_t zero) { | ||||
|   const uint32_t c0 = 0x64006400; | ||||
|   const half y16_ = __float2half_rn(1.0f / 16.0f); | ||||
|   const half2 y16 = __halves2half2(y16_, y16_); | ||||
|   const half_uint16 z1_(0xe400 | zero);  // half(-1024.0f - zero); | ||||
|   const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
|   const half2 z1 = __half2half2(z1_.as_half); | ||||
|   const half2 z16 = __half2half2(z16_); | ||||
|  | ||||
|     half2 scale2 = __half2half2(scale); | ||||
|   uint32_t qa = q_0; | ||||
|   half2_uint32 q0((qa & 0x000f000f) | c0);  // half2(q[ 0], q[ 1])      + 1024 | ||||
|   half2_uint32 q1((qa & 0x00f000f0) | c0);  // half2(q[ 2], q[ 3]) * 16 + 1024 | ||||
|   qa >>= 8; | ||||
|   half2_uint32 q2((qa & 0x000f000f) | c0);  // half2(q[ 4], q[ 5])      + 1024 | ||||
|   half2_uint32 q3((qa & 0x00f000f0) | c0);  // half2(q[ 6], q[ 7]) * 16 + 1024 | ||||
|  | ||||
|     z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); | ||||
|     z1z16[1] = __hmul2(scale2, __half2half2(z16)); | ||||
|  | ||||
|     const half y1 = __float2half_rn(1.0f); | ||||
|     const half y16 = __float2half_rn(1.0f / 16.0f); | ||||
|  | ||||
|     y1y16[0] = __hmul2(scale2, __half2half2(y1)); | ||||
|     y1y16[1] = __hmul2(scale2, __half2half2(y16)); | ||||
|   dq[0] = __hadd2(q0.as_half2, z1); | ||||
|   dq[1] = __hfma2(q1.as_half2, y16, z16); | ||||
|   dq[2] = __hadd2(q2.as_half2, z1); | ||||
|   dq[3] = __hfma2(q3.as_half2, y16, z16); | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ void dequant_4bit_8_prep_zero | ||||
| ( | ||||
|     const uint32_t zero, | ||||
|     half2(&z1z16)[2], | ||||
|     half2(&y1y16)[2] | ||||
| ) | ||||
| { | ||||
|     half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); | ||||
|     half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
| __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( | ||||
|     const uint32_t zero, const half scale, half2 (&z1z16)[2], | ||||
|     half2 (&y1y16)[2]) { | ||||
|   half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero); | ||||
|   half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
|  | ||||
|     z1z16[0] = __half2half2(z1.as_half); | ||||
|     z1z16[1] = __half2half2(z16); | ||||
|   half2 scale2 = __half2half2(scale); | ||||
|  | ||||
|     const half y1 = __float2half_rn(1.0f); | ||||
|     const half y16 = __float2half_rn(1.0f / 16.0f); | ||||
|   z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); | ||||
|   z1z16[1] = __hmul2(scale2, __half2half2(z16)); | ||||
|  | ||||
|     y1y16[0] = __half2half2(y1); | ||||
|     y1y16[1] = __half2half2(y16); | ||||
|   const half y1 = __float2half_rn(1.0f); | ||||
|   const half y16 = __float2half_rn(1.0f / 16.0f); | ||||
|  | ||||
|   y1y16[0] = __hmul2(scale2, __half2half2(y1)); | ||||
|   y1y16[1] = __hmul2(scale2, __half2half2(y16)); | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, | ||||
|                                                          half2 (&z1z16)[2], | ||||
|                                                          half2 (&y1y16)[2]) { | ||||
|   half_uint16 z1(0xe400 | zero);  // half(-1024.0f - zero); | ||||
|   half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); | ||||
|  | ||||
| __forceinline__ __device__ void dequant_4bit_8_gptq | ||||
| ( | ||||
|     const uint32_t q_0, | ||||
|     half2 (&dq)[4], | ||||
|     half2 (&z1z16)[2], | ||||
|     half2 (&y1y16)[2], | ||||
|     int stride, | ||||
|     bool scaled | ||||
| ) | ||||
| { | ||||
|     const uint32_t c0 = 0x64006400; | ||||
|   z1z16[0] = __half2half2(z1.as_half); | ||||
|   z1z16[1] = __half2half2(z16); | ||||
|  | ||||
|     uint32_t qa = q_0; | ||||
|     half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 ) | ||||
|     half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) | ||||
|     qa >>= 8; | ||||
|     half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 ) | ||||
|     half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) | ||||
|   const half y1 = __float2half_rn(1.0f); | ||||
|   const half y16 = __float2half_rn(1.0f / 16.0f); | ||||
|  | ||||
|     if (scaled) | ||||
|     { | ||||
|         dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s) | ||||
|         dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s) | ||||
|         dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); | ||||
|         dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); | ||||
|     } | ||||
|     else | ||||
|     { | ||||
|         dq[0] = __hadd2(q0.as_half2,           z1z16[0]);  // half2( q[0] - z, q[1] - z ) | ||||
|         dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] - z, q[3] - z ) | ||||
|         dq[2] = __hadd2(q2.as_half2,           z1z16[0]);  // half2( q[4] - z, q[5] - z ) | ||||
|         dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);  // half2( q[6] - z, q[7] - z ) | ||||
|     } | ||||
|   y1y16[0] = __half2half2(y1); | ||||
|   y1y16[1] = __half2half2(y16); | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, | ||||
|                                                     half2 (&dq)[4], | ||||
|                                                     half2 (&z1z16)[2], | ||||
|                                                     half2 (&y1y16)[2], | ||||
|                                                     int stride, bool scaled) { | ||||
|   const uint32_t c0 = 0x64006400; | ||||
|  | ||||
|   uint32_t qa = q_0; | ||||
|   half2_uint32 q0((qa & 0x000f000f) | | ||||
|                   c0);  // half2( q[0]      + 1024, q[1]      + 1024 ) | ||||
|   half2_uint32 q1((qa & 0x00f000f0) | | ||||
|                   c0);  // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) | ||||
|   qa >>= 8; | ||||
|   half2_uint32 q2((qa & 0x000f000f) | | ||||
|                   c0);  // half2( q[4]      + 1024, q[5]      + 1024 ) | ||||
|   half2_uint32 q3((qa & 0x00f000f0) | | ||||
|                   c0);  // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) | ||||
|  | ||||
|   if (scaled) { | ||||
|     dq[0] = __hfma2(q0.as_half2, y1y16[0], | ||||
|                     z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s) | ||||
|     dq[1] = __hfma2(q1.as_half2, y1y16[1], | ||||
|                     z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s) | ||||
|     dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); | ||||
|     dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); | ||||
|   } else { | ||||
|     dq[0] = __hadd2(q0.as_half2, z1z16[0]);  // half2( q[0] - z, q[1] - z ) | ||||
|     dq[1] = __hfma2(q1.as_half2, y1y16[1], | ||||
|                     z1z16[1]);               // half2( q[2] - z, q[3] - z ) | ||||
|     dq[2] = __hadd2(q2.as_half2, z1z16[0]);  // half2( q[4] - z, q[5] - z ) | ||||
|     dq[3] = __hfma2(q3.as_half2, y1y16[1], | ||||
|                     z1z16[1]);  // half2( q[6] - z, q[7] - z ) | ||||
|   } | ||||
| } | ||||
| }  // namespace gptq | ||||
| }  // namespace vllm | ||||
|  | ||||
| @ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2 | ||||
| namespace vllm { | ||||
| namespace gptq { | ||||
|  | ||||
| __forceinline__ __device__ void shuffle_8bit_4 | ||||
| ( | ||||
|     uint32_t* q, | ||||
|     int stride | ||||
| ) | ||||
| { | ||||
| } | ||||
| __forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} | ||||
|  | ||||
| __forceinline__ __device__ void dequant_8bit_8 | ||||
| ( | ||||
|     const uint32_t q_0, | ||||
|     const uint32_t q_1, | ||||
|     half2 (&dq)[4], | ||||
|     int stride, | ||||
|     const uint32_t zero | ||||
| ) | ||||
| { | ||||
|     half dqh[8]; | ||||
|     for (int i = 0; i < 4; i++) dqh[i    ] = dq_ns(exb(q_0, i * 8, 0xff), zero); | ||||
|     for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); | ||||
| __forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, | ||||
|                                                const uint32_t q_1, | ||||
|                                                half2 (&dq)[4], int stride, | ||||
|                                                const uint32_t zero) { | ||||
|   half dqh[8]; | ||||
|   for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); | ||||
|   for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); | ||||
|  | ||||
|     for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); | ||||
|   for (int i = 0; i < 4; i++) | ||||
|     dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); | ||||
| } | ||||
|  | ||||
| }  // namespace gptq | ||||
|  | ||||
| @ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2 | ||||
| namespace vllm { | ||||
| namespace gptq { | ||||
|  | ||||
| union half2_uint32 | ||||
| { | ||||
|     uint32_t as_uint32; | ||||
|     half2 as_half2; | ||||
|     __device__ half2_uint32(uint32_t val) : as_uint32(val) {} | ||||
|     __device__ half2_uint32(half2 val) : as_half2(val) {} | ||||
| union half2_uint32 { | ||||
|   uint32_t as_uint32; | ||||
|   half2 as_half2; | ||||
|   __device__ half2_uint32(uint32_t val) : as_uint32(val) {} | ||||
|   __device__ half2_uint32(half2 val) : as_half2(val) {} | ||||
| }; | ||||
|  | ||||
| union half_uint16 | ||||
| { | ||||
|     uint16_t as_uint16; | ||||
|     half as_half; | ||||
|     __device__ half_uint16(uint16_t val) : as_uint16(val) {} | ||||
|     __device__ half_uint16(half val) : as_half(val) {} | ||||
| union half_uint16 { | ||||
|   uint16_t as_uint16; | ||||
|   half as_half; | ||||
|   __device__ half_uint16(uint16_t val) : as_uint16(val) {} | ||||
|   __device__ half_uint16(half val) : as_half(val) {} | ||||
| }; | ||||
|  | ||||
| // Max_scale premultiplied by 1/256 | ||||
|  | ||||
| __forceinline__ __device__ half dq_scale(const int qs, const half max_scale) | ||||
| { | ||||
|     int qs_i = qs + 1; | ||||
|     half qs_h = __int2half_rn(qs_i * qs_i); | ||||
|     qs_h = __hmul(qs_h, max_scale); | ||||
|     return qs_h; | ||||
| __forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { | ||||
|   int qs_i = qs + 1; | ||||
|   half qs_h = __int2half_rn(qs_i * qs_i); | ||||
|   qs_h = __hmul(qs_h, max_scale); | ||||
|   return qs_h; | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ half dq(const int q, const int qzero, const half scale) | ||||
| { | ||||
|     return __hmul(__int2half_rn(q - qzero), scale); | ||||
| __forceinline__ __device__ half dq(const int q, const int qzero, | ||||
|                                    const half scale) { | ||||
|   return __hmul(__int2half_rn(q - qzero), scale); | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ half dq_ns(const int q, const int qzero) | ||||
| { | ||||
|     //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); | ||||
|     return __int2half_rn(q - qzero); | ||||
| __forceinline__ __device__ half dq_ns(const int q, const int qzero) { | ||||
|   // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); | ||||
|   return __int2half_rn(q - qzero); | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) | ||||
| { | ||||
|     return (int)((q >> shift) & mask); | ||||
| __forceinline__ __device__ int exb(const uint32_t q, const int shift, | ||||
|                                    const int mask) { | ||||
|   return (int)((q >> shift) & mask); | ||||
| } | ||||
|  | ||||
| __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) | ||||
| { | ||||
|     return (int)(__funnelshift_rc(q0, q1, shift) & mask); | ||||
| __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, | ||||
|                                    const int shift, const int mask) { | ||||
|   return (int)(__funnelshift_rc(q0, q1, shift) & mask); | ||||
| } | ||||
|  | ||||
| }  // namespace gptq | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -11,22 +11,23 @@ | ||||
|  | ||||
| namespace gptq_marlin { | ||||
|  | ||||
| // 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per | ||||
| // schedule allows some more latency hiding. At the same time, we want relatively few warps to have | ||||
| // many registers per warp and small tiles. | ||||
| // 8 warps are a good choice since every SM has 4 schedulers and having more | ||||
| // than 1 warp per schedule allows some more latency hiding. At the same time, | ||||
| // we want relatively few warps to have many registers per warp and small tiles. | ||||
| static constexpr int default_threads = 256; | ||||
|  | ||||
| static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory | ||||
| static constexpr int pipe_stages = | ||||
|     4;  // 4 pipeline stages fit into shared memory | ||||
|  | ||||
| static constexpr int min_thread_n = 64; | ||||
| static constexpr int min_thread_k = 64; | ||||
|  | ||||
| static constexpr int tile_size = 16; | ||||
| static constexpr int max_par   = 16; | ||||
| static constexpr int max_par = 16; | ||||
|  | ||||
| template <typename T, int n> | ||||
| struct Vec { | ||||
|   T             elems[n]; | ||||
|   T elems[n]; | ||||
|   __device__ T& operator[](int i) { return elems[i]; } | ||||
| }; | ||||
|  | ||||
| @ -35,30 +36,35 @@ using I4 = Vec<int, 4>; | ||||
| constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } | ||||
|  | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||
|   // No support for async | ||||
| // No support for async | ||||
| #else | ||||
|  | ||||
| __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { | ||||
| __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, | ||||
|                                       bool pred = true) { | ||||
|   const int BYTES = 16; | ||||
|   uint32_t  smem  = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||||
|   asm volatile("{\n" | ||||
|                "   .reg .pred p;\n" | ||||
|                "   setp.ne.b32 p, %0, 0;\n" | ||||
|                "   @p cp.async.cg.shared.global [%1], [%2], %3;\n" | ||||
|                "}\n" ::"r"((int)pred), | ||||
|                "r"(smem), "l"(glob_ptr), "n"(BYTES)); | ||||
|   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||||
|   asm volatile( | ||||
|       "{\n" | ||||
|       "   .reg .pred p;\n" | ||||
|       "   setp.ne.b32 p, %0, 0;\n" | ||||
|       "   @p cp.async.cg.shared.global [%1], [%2], %3;\n" | ||||
|       "}\n" ::"r"((int)pred), | ||||
|       "r"(smem), "l"(glob_ptr), "n"(BYTES)); | ||||
| } | ||||
|  | ||||
| __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { | ||||
|   const int BYTES = 16; | ||||
|   uint32_t  smem  = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||||
|   asm volatile("{\n" | ||||
|                "   cp.async.cg.shared.global [%0], [%1], %2;\n" | ||||
|                "}\n" ::"r"(smem), | ||||
|                "l"(glob_ptr), "n"(BYTES)); | ||||
|   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||||
|   asm volatile( | ||||
|       "{\n" | ||||
|       "   cp.async.cg.shared.global [%0], [%1], %2;\n" | ||||
|       "}\n" ::"r"(smem), | ||||
|       "l"(glob_ptr), "n"(BYTES)); | ||||
| } | ||||
|  | ||||
| __device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } | ||||
| __device__ inline void cp_async_fence() { | ||||
|   asm volatile("cp.async.commit_group;\n" ::); | ||||
| } | ||||
|  | ||||
| template <int n> | ||||
| __device__ inline void cp_async_wait() { | ||||
| @ -67,4 +73,4 @@ __device__ inline void cp_async_wait() { | ||||
|  | ||||
| #endif | ||||
|  | ||||
| } // namespace gptq_marlin | ||||
| }  // namespace gptq_marlin | ||||
|  | ||||
							
								
								
									
										77
									
								
								csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,77 @@ | ||||
|  | ||||
| #ifndef _data_types_cuh | ||||
| #define _data_types_cuh | ||||
| #include "gptq_marlin.cuh" | ||||
| #include <cuda_fp16.h> | ||||
| #include <cuda_bf16.h> | ||||
|  | ||||
| namespace gptq_marlin { | ||||
|  | ||||
| template <typename scalar_t> | ||||
| class ScalarType {}; | ||||
|  | ||||
| template <> | ||||
| class ScalarType<half> { | ||||
|  public: | ||||
|   using scalar_t = half; | ||||
|   using scalar_t2 = half2; | ||||
|  | ||||
|   // Matrix fragments for tensor core instructions; their precise layout is | ||||
|   // documented here: | ||||
|   // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type | ||||
|   using FragA = Vec<half2, 4>; | ||||
|   using FragB = Vec<half2, 2>; | ||||
|   using FragC = Vec<float, 4>; | ||||
|   using FragS = Vec<half2, 1>; | ||||
|  | ||||
|   static __device__ float inline num2float(const half x) { | ||||
|     return __half2float(x); | ||||
|   } | ||||
|  | ||||
|   static __device__ half2 inline num2num2(const half x) { | ||||
|     return __half2half2(x); | ||||
|   } | ||||
|  | ||||
|   static __device__ half2 inline nums2num2(const half x1, const half x2) { | ||||
|     return __halves2half2(x1, x2); | ||||
|   } | ||||
|  | ||||
|   static __host__ __device__ half inline float2num(const float x) { | ||||
|     return __float2half(x); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| class ScalarType<nv_bfloat16> { | ||||
|  public: | ||||
|   using scalar_t = nv_bfloat16; | ||||
|   using scalar_t2 = nv_bfloat162; | ||||
|  | ||||
|   using FragA = Vec<nv_bfloat162, 4>; | ||||
|   using FragB = Vec<nv_bfloat162, 2>; | ||||
|   using FragC = Vec<float, 4>; | ||||
|   using FragS = Vec<nv_bfloat162, 1>; | ||||
|  | ||||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 | ||||
|   static __device__ float inline num2float(const nv_bfloat16 x) { | ||||
|     return __bfloat162float(x); | ||||
|   } | ||||
|  | ||||
|   static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { | ||||
|     return __bfloat162bfloat162(x); | ||||
|   } | ||||
|  | ||||
|   static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, | ||||
|                                                   const nv_bfloat16 x2) { | ||||
|     return __halves2bfloat162(x1, x2); | ||||
|   } | ||||
|  | ||||
|   static __host__ __device__ nv_bfloat16 inline float2num(const float x) { | ||||
|     return __float2bfloat16(x); | ||||
|   } | ||||
| #endif | ||||
| }; | ||||
|  | ||||
| }  // namespace gptq_marlin | ||||
|  | ||||
| #endif | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	