Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: il_tool...v0.10.2rc1 (36 commits)

Commits (SHA1):
b8a93076d3
c3f9773b2c
3707cb2505
920ed46b09
15cb047e25
9ad0688e43
b9a1c4c8a2
1aa427fdc1
1c63a16b65
922d3b401b
19332c0479
a55cf41a09
6fb2788163
3d2a2de8f7
1116590b16
ccb97338af
45c9cb5835
e283976f3a
46876dff32
1823a00d67
ed16d0f26f
0cdd213641
948dd3443b
b2f7745774
82dfb12e52
bba1042c6f
b6fbc15634
3e0d4a3475
562663a044
ed1623a88a
13b89bd823
22a0070530
170129eb28
955c624915
4f87abdcc6
6910b56da2
@@ -149,3 +149,25 @@ steps:
      - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
    env:
      DOCKER_BUILDKIT: "1"

  - label: "Build and publish nightly multi-arch image to DockerHub"
    depends_on:
      - create-multi-arch-manifest
    if: build.env("NIGHTLY") == "1"
    agents:
      queue: cpu_queue_postmerge
    commands:
      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
      - "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
      - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly"
      - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
      - "docker push vllm/vllm-openai:nightly"
      - "docker push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
      # Clean up old nightly builds (keep only last 14)
      - "bash .buildkite/scripts/cleanup-nightly-builds.sh"
    plugins:
      - docker-login#v3.0.0:
          username: vllmbot
          password-env: DOCKERHUB_TOKEN
    env:
      DOCKER_BUILDKIT: "1"
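Once this step runs, the nightly images should be pullable under exactly the tags pushed above; a minimal sketch (the commit SHA suffix is a placeholder):

```bash
docker pull vllm/vllm-openai:nightly
docker pull vllm/vllm-openai:nightly-<full-commit-sha>
```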
.buildkite/scripts/cleanup-nightly-builds.sh (new executable file, 97 lines)
@@ -0,0 +1,97 @@
#!/bin/bash

set -ex

# Clean up old nightly builds from DockerHub, keeping only the last 14 builds
# This script uses DockerHub API to list and delete old tags with "nightly-" prefix

# DockerHub API endpoint for vllm/vllm-openai repository
REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"

# Get DockerHub token from environment
if [ -z "$DOCKERHUB_TOKEN" ]; then
    echo "Error: DOCKERHUB_TOKEN environment variable is not set"
    exit 1
fi

# Function to get all tags from DockerHub
get_all_tags() {
    local page=1
    local all_tags=""

    while true; do
        local response=$(curl -s -H "Authorization: Bearer $DOCKERHUB_TOKEN" \
            "$REPO_API_URL?page=$page&page_size=100")

        # Get both last_updated timestamp and tag name, separated by |
        local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"')

        if [ -z "$tags" ]; then
            break
        fi

        all_tags="$all_tags$tags"$'\n'
        page=$((page + 1))
    done

    # Sort by timestamp (newest first) and extract just the tag names
    echo "$all_tags" | sort -r | cut -d'|' -f2
}

delete_tag() {
    local tag_name="$1"
    echo "Deleting tag: $tag_name"

    local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name"
    local response=$(curl -s -X DELETE -H "Authorization: Bearer $DOCKERHUB_TOKEN" "$delete_url")

    if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then
        echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')"
    else
        echo "Successfully deleted tag: $tag_name"
    fi
}

# Get all nightly- prefixed tags, sorted by last_updated timestamp (newest first)
echo "Fetching all tags from DockerHub..."
all_tags=$(get_all_tags)

if [ -z "$all_tags" ]; then
    echo "No tags found to clean up"
    exit 0
fi

# Count total tags
total_tags=$(echo "$all_tags" | wc -l)
echo "Found $total_tags tags"

# Keep only the last 14 builds (including the current one)
tags_to_keep=14
tags_to_delete=$((total_tags - tags_to_keep))

if [ $tags_to_delete -le 0 ]; then
    echo "No tags need to be deleted (only $total_tags tags found, keeping $tags_to_keep)"
    exit 0
fi

echo "Will delete $tags_to_delete old tags, keeping the newest $tags_to_keep"

# Get tags to delete (skip the first $tags_to_keep tags)
tags_to_delete_list=$(echo "$all_tags" | tail -n +$((tags_to_keep + 1)))

if [ -z "$tags_to_delete_list" ]; then
    echo "No tags to delete"
    exit 0
fi

# Delete old tags
echo "Deleting old tags..."
while IFS= read -r tag; do
    if [ -n "$tag" ]; then
        delete_tag "$tag"
        # Add a small delay to avoid rate limiting
        sleep 1
    fi
done <<< "$tags_to_delete_list"

echo "Cleanup completed successfully"
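If the cleanup ever needs to be run by hand, the script above only requires the token it already checks for; a hedged sketch of a manual invocation (the token value is a placeholder):

```bash
export DOCKERHUB_TOKEN="<dockerhub-token-with-delete-permission>"
bash .buildkite/scripts/cleanup-nightly-builds.sh
```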
@@ -379,11 +379,7 @@ steps:
  - tests/compile
  commands:
  - pytest -v -s compile/test_basic_correctness.py
  # these tests need to be separated, cannot combine
  - pytest -v -s compile/piecewise/test_simple.py
  - pytest -v -s compile/piecewise/test_toy_llama.py
  - pytest -v -s compile/piecewise/test_full_cudagraph.py
  - pytest -v -s compile/piecewise/test_multiple_graphs.py
  - pytest -v -s compile/piecewise/

- label: PyTorch Fullgraph Test # 20min
  timeout_in_minutes: 30
.github/mergify.yml (vendored, 14 lines changed)
@@ -273,6 +273,20 @@ pull_request_rules:
        users:
          - "sangstar"

  - name: assign reviewer for modelopt changes
    conditions:
      - or:
          - files~=^vllm/model_executor/layers/quantization/modelopt\.py$
          - files~=^vllm/model_executor/layers/quantization/__init__\.py$
          - files~=^tests/models/quantization/test_modelopt\.py$
          - files~=^tests/quantization/test_modelopt\.py$
          - files~=^tests/models/quantization/test_nvfp4\.py$
          - files~=^docs/features/quantization/modelopt\.md$
    actions:
      assign:
        users:
          - "Edwardf0t1"

  - name: remove 'needs-rebase' label when conflict is resolved
    conditions:
      - -conflict
.github/workflows/add_label_automerge.yml (vendored, 2 lines changed)
@@ -10,7 +10,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Add label
        uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          script: |
            github.rest.issues.addLabels({
.github/workflows/cleanup_pr_body.yml (vendored, 2 lines changed)
@@ -16,7 +16,7 @@ jobs:
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: '3.12'
.github/workflows/issue_autolabel.yml (vendored, 2 lines changed)
@@ -13,7 +13,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Label issues based on keywords
        uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          script: |
            // Configuration: Add new labels and keywords here
.github/workflows/pre-commit.yml (vendored, 2 lines changed)
@@ -17,7 +17,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: "3.12"
      - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
.github/workflows/reminder_comment.yml (vendored, 2 lines changed)
@@ -9,7 +9,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Remind to run full CI on PR
        uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          script: |
            try {
.github/workflows/stale.yml (vendored, 2 lines changed)
@@ -13,7 +13,7 @@ jobs:
      actions: write
    runs-on: ubuntu-latest
    steps:
      - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0
      - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
        with:
          # Increasing this value ensures that changes to this workflow
          # propagate to all issues and PRs in days rather than months
.gitignore (vendored, 12 lines changed)
@@ -4,7 +4,7 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*

# triton jit
# triton jit
.triton

# Byte-compiled / optimized / DLL files
@@ -177,6 +177,14 @@ cython_debug/
# VSCode
.vscode/

# Claude
CLAUDE.md
.claude/

# Codex
AGENTS.md
.codex/

# DS Store
.DS_Store

@@ -209,4 +217,4 @@ shellcheck*/
csrc/moe/marlin_moe_wna16/kernel_*

# Ignore ep_kernels_workspace folder
ep_kernels_workspace/
ep_kernels_workspace/
@@ -694,7 +694,7 @@ python -m vllm.entrypoints.openai.api_server \
Send requests with images:

```bash
python benchmarks/benchmark_serving.py \
vllm bench serve \
    --backend openai-chat \
    --model Qwen/Qwen2.5-VL-7B-Instruct \
    --dataset-name sharegpt \
@@ -721,7 +721,7 @@ python -m vllm.entrypoints.openai.api_server \
Send requests with videos:

```bash
python benchmarks/benchmark_serving.py \
vllm bench serve \
    --backend openai-chat \
    --model Qwen/Qwen2.5-VL-7B-Instruct \
    --dataset-name sharegpt \
@@ -1,191 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={"latency": results["latencies"]},
|
||||
extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
|
||||
)
|
||||
if pt_records:
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"benchmark_latency.py is deprecated and will be removed in a "
|
||||
"future version. Please use 'vllm bench latency' instead.",
|
||||
)
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert llm.llm_engine.model_config.max_model_len >= (
|
||||
args.input_len + args.output_len
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len."
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = np.random.randint(
|
||||
10000, size=(args.batch_size, args.input_len)
|
||||
)
|
||||
dummy_prompts: list[PromptType] = [
|
||||
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
|
||||
]
|
||||
|
||||
def llm_generate():
|
||||
if not args.use_beam_search:
|
||||
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
|
||||
else:
|
||||
llm.beam_search(
|
||||
dummy_prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=args.n,
|
||||
max_tokens=args.output_len,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
llm.start_profile()
|
||||
llm_generate()
|
||||
llm.stop_profile()
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm_generate()
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=profile_dir)
|
||||
return
|
||||
|
||||
# Benchmark.
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
latencies = np.array(latencies)
|
||||
percentages = [10, 25, 50, 75, 90, 99]
|
||||
percentiles = np.percentile(latencies, percentages)
|
||||
print(f"Avg latency: {np.mean(latencies)} seconds")
|
||||
for percentage, percentile in zip(percentages, percentiles):
|
||||
print(f"{percentage}% percentile latency: {percentile} seconds")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"avg_latency": np.mean(latencies),
|
||||
"latencies": latencies.tolist(),
|
||||
"percentiles": dict(zip(percentages, percentiles.tolist())),
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
|
||||
|
||||
def create_argument_parser():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the latency of processing a single batch of "
|
||||
"requests till completion."
|
||||
)
|
||||
parser.add_argument("--input-len", type=int, default=32)
|
||||
parser.add_argument("--output-len", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
"--num-iters-warmup",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations to run for warmup.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-iters", type=int, default=30, help="Number of iterations to run."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="profile the generation process of a single batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the latency results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"
|
||||
),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# V1 enables prefix caching by default which skews the latency
|
||||
# numbers. We need to disable prefix caching by default.
|
||||
parser.set_defaults(enable_prefix_caching=False)
|
||||
|
||||
return parser
|
||||
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
|
||||
raise OSError(
|
||||
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
|
||||
"Please set it to a valid path to use torch profiler."
|
||||
)
|
||||
main(args)
|
||||
print("""DEPRECATED: This script has been moved to the vLLM CLI.
|
||||
|
||||
Please use the following command instead:
|
||||
vllm bench latency
|
||||
|
||||
For help with the new command, run:
|
||||
vllm bench latency --help
|
||||
|
||||
Alternatively, you can run the new command directly with:
|
||||
python -m vllm.entrypoints.cli.main bench latency --help
|
||||
""")
|
||||
sys.exit(1)
|
||||
|
File diff suppressed because it is too large
@@ -1,741 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark offline inference throughput."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from benchmark_dataset import (
|
||||
AIMODataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
RandomDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data,
|
||||
)
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(
|
||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
||||
)
|
||||
)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
|
||||
)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0].expected_output_len
|
||||
for request in requests:
|
||||
assert request.expected_output_len == output_len
|
||||
start = time.perf_counter()
|
||||
llm.beam_search(
|
||||
prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
start = time.perf_counter()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args,
|
||||
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
|
||||
) as llm:
|
||||
model_config = await llm.get_model_config()
|
||||
assert all(
|
||||
model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
lora_requests: list[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data,
|
||||
)
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(
|
||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
||||
)
|
||||
)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp, lr) in enumerate(
|
||||
zip(prompts, sampling_params, lora_requests)
|
||||
):
|
||||
generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.perf_counter()
|
||||
batch: list[str] = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
prompt = requests[i].prompt
|
||||
prompt_len = requests[i].prompt_len
|
||||
output_len = requests[i].expected_output_len
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_output_len = max(max_output_len, output_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (
|
||||
max(max_prompt_len, next_prompt_len)
|
||||
+ max(max_output_len, next_output_len)
|
||||
) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=True,
|
||||
num_return_sequences=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
use_cache=True,
|
||||
max_new_tokens=max_output_len,
|
||||
)
|
||||
if not disable_detokenize:
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
pbar.update(len(batch))
|
||||
|
||||
# Clear the batch.
|
||||
batch = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_mii(
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tensor_parallel_size: int,
|
||||
output_len: int,
|
||||
) -> float:
|
||||
from mii import client, serve
|
||||
|
||||
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
||||
prompts = [request.prompt for request in requests]
|
||||
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts, max_new_tokens=output_len)
|
||||
end = time.perf_counter()
|
||||
client = client(model)
|
||||
client.terminate_server()
|
||||
return end - start
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"requests_per_second": [results["requests_per_second"]],
|
||||
"tokens_per_second": [results["tokens_per_second"]],
|
||||
},
|
||||
extra_info={
|
||||
k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
||||
},
|
||||
)
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": args.num_prompts,
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset."
|
||||
)
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
common_kwargs["no_stream"] = args.no_stream
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs["dataset_subset"] = args.hf_subset
|
||||
common_kwargs["dataset_split"] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = AIMODataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"benchmark_throughput.py is deprecated and will be removed in a "
|
||||
"future version. Please use 'vllm bench throughput' instead.",
|
||||
)
|
||||
def main(args: argparse.Namespace):
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code
|
||||
)
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
run_vllm_async(
|
||||
requests,
|
||||
args.n,
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
args.disable_frontend_multiprocessing,
|
||||
args.disable_detokenize,
|
||||
)
|
||||
)
|
||||
else:
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests,
|
||||
args.n,
|
||||
EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize,
|
||||
)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(
|
||||
requests,
|
||||
args.model,
|
||||
tokenizer,
|
||||
args.n,
|
||||
args.hf_max_batch_size,
|
||||
args.trust_remote_code,
|
||||
args.disable_detokenize,
|
||||
)
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(
|
||||
requests, args.model, args.tensor_parallel_size, args.output_len
|
||||
)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
|
||||
if request_outputs:
|
||||
# Note: with the vllm and vllm-chat backends,
|
||||
# we have request_outputs, which we use to count tokens.
|
||||
total_prompt_tokens = 0
|
||||
total_output_tokens = 0
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += (
|
||||
len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
)
|
||||
total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print(
|
||||
"\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details."
|
||||
)
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
|
||||
)
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": total_num_tokens / elapsed_time,
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments.
|
||||
"""
|
||||
|
||||
# === Deprecation and Defaulting ===
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
|
||||
# === Backend Validation ===
|
||||
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset and not args.dataset_path:
|
||||
print("When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = "random"
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm-chat", (
|
||||
f"{args.dataset_path} needs to use vllm-chat as the backend."
|
||||
) # noqa: E501
|
||||
elif args.dataset_path in (
|
||||
InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm", (
|
||||
f"{args.dataset_path} needs to use vllm as the backend."
|
||||
) # noqa: E501
|
||||
else:
|
||||
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != "random" and args.random_range_ratio is not None:
|
||||
warnings.warn(
|
||||
"--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if (
|
||||
args.dataset_name not in {"random", "sonnet", None}
|
||||
and args.prefix_len is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
# === Backend-specific Validations ===
|
||||
if args.backend == "hf" and args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend")
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if (
|
||||
args.backend in {"hf", "mii"}
|
||||
and getattr(args, "quantization", None) is not None
|
||||
):
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, "
|
||||
"please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
def create_argument_parser():
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path", type=str, default=None, help="Path to the dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the throughput results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--async-engine",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-frontend-multiprocessing",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"
|
||||
),
|
||||
)
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the LoRA adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help=f"Number of prefix tokens to be used in RandomDataset "
|
||||
"and SonnetDataset. For RandomDataset, the total input "
|
||||
"length is the sum of prefix-len (default: "
|
||||
f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
|
||||
"sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]. For SonnetDataset, "
|
||||
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
||||
"controls how much of the input is fixed lines versus "
|
||||
"random lines, but the total input length remains approximately "
|
||||
"input_len tokens.",
|
||||
)
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
|
||||
"for sampling input/output length, "
|
||||
"used only for RandomDataset. Must be in the range [0, 1) to "
|
||||
"define a symmetric sampling range "
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
|
||||
# hf dataset
|
||||
parser.add_argument(
|
||||
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-split", type=str, default=None, help="Split of the HF dataset."
|
||||
)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
validate_args(args)
|
||||
main(args)
|
||||
print("""DEPRECATED: This script has been moved to the vLLM CLI.
|
||||
|
||||
Please use the following command instead:
|
||||
vllm bench throughput
|
||||
|
||||
For help with the new command, run:
|
||||
vllm bench throughput --help
|
||||
|
||||
Alternatively, you can run the new command directly with:
|
||||
python -m vllm.entrypoints.cli.main bench throughput --help
|
||||
""")
|
||||
sys.exit(1)
|
||||
|
@@ -259,6 +259,7 @@ if __name__ == "__main__":
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (None, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

@@ -274,6 +274,7 @@ if __name__ == "__main__":
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (FP8_DTYPE, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]
@@ -47,6 +47,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite

# -----------------------
@@ -71,7 +72,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
RUN cd /vllm-workspace \
    && rm -rf vllm \
    && python3 -m pip install -e tests/vllm_test_utils \
    && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \
    && python3 -m pip install lm-eval[api]==0.4.4 \
    && python3 -m pip install pytest-shard

# -----------------------
@@ -100,6 +101,7 @@ ARG COMMON_WORKDIR
# Copy over the benchmark scripts as well
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
COPY --from=export_vllm /docker ${COMMON_WORKDIR}/vllm/docker

ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
ENV TOKENIZERS_PARALLELISM=false
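For reference, an image from this Dockerfile is typically built with something along these lines (the tag name is arbitrary and the path follows the COPY line above; adjust to your checkout):

```bash
DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm .
```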
@@ -1,18 +1,16 @@
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
ARG HIPBLASLT_BRANCH="db8e93b4"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.4.1-complete
ARG HIPBLASLT_BRANCH="aa0bda7b"
ARG HIPBLAS_COMMON_BRANCH="9b80ba8e"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="295f2ed4"
ARG PYTORCH_BRANCH="f717b2af"
ARG PYTORCH_VISION_BRANCH="v0.21.0"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="916bf3c"
ARG AITER_BRANCH="4822e675"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
@ -45,7 +43,7 @@ RUN apt-get update -y \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython
|
||||
RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
|
||||
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH
|
||||
@ -53,6 +51,7 @@ ARG HIPBLAS_COMMON_BRANCH
|
||||
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
RUN git clone https://github.com/ROCm/hipBLAS-common.git
|
||||
RUN apt-get remove -y hipblaslt && apt-get autoremove -y && apt-get autoclean -y
|
||||
RUN cd hipBLAS-common \
|
||||
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
|
||||
&& mkdir build \
|
||||
@ -69,24 +68,17 @@ RUN cd hipBLASLt \
|
||||
&& make package
|
||||
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
|
||||
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
RUN git clone ${RCCL_REPO}
|
||||
RUN cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
|
||||
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
RUN git clone ${TRITON_REPO}
|
||||
RUN cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
|
||||
&& if [ ! -f setup.py ]; then cd python; fi \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& mkdir -p /app/install && cp dist/*.whl /app/install
|
||||
RUN if [ -d triton/python/triton_kernels ]; then pip install build && cd triton/python/triton_kernels \
|
||||
&& python3 -m build --wheel && cp dist/*.whl /app/install; fi
|
||||
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
@ -132,15 +124,25 @@ RUN cd aiter \
|
||||
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
|
||||
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
|
||||
|
||||
FROM base AS debs
|
||||
RUN mkdir /app/debs
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
cp /install/*.deb /app/debs
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
|
||||
FROM base AS final
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
|
||||
&& perl -p -i -e 's/, hipblas-common-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \
|
||||
&& perl -p -i -e 's/, hipblaslt-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \
|
||||
&& perl -p -i -e 's/, hipblaslt \([^)]*?\), /, /g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
@ -154,8 +156,6 @@ ARG BASE_IMAGE
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
ARG PYTORCH_BRANCH
|
||||
@ -170,8 +170,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
|
||||
&& echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
|
||||
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
|
||||
@ -180,4 +178,4 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
@@ -169,7 +169,7 @@ All Llama 3.1, 3.2 and 4 models should be supported.

The tool calling that is supported is the [JSON-based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for Llama 4 models, it is recommended to use the `llama4_pythonic` tool parser.

Other tool calling formats like the built in python tool calling or custom tool calling are not supported.
Other tool calling formats like the built-in python tool calling or custom tool calling are not supported.

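For illustration only, a hedged sketch of how these parsers are typically selected at serve time (flag and parser names assumed from vLLM's tool-calling options; model names are placeholders):

```bash
# Assumed flags and parser names; adjust to the deployed vLLM version.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --enable-auto-tool-choice --tool-call-parser llama3_json

vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \
    --enable-auto-tool-choice --tool-call-parser llama4_pythonic
```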
Known issues:
@@ -119,7 +119,7 @@ Currently, there are no pre-built ROCm wheels.
This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation.

!!! tip
    - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
    - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm-up step before collecting perf numbers.
    - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
    - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention.
    - The ROCm version of PyTorch, ideally, should match the ROCm driver version.
@@ -322,6 +322,7 @@ th {

| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
@@ -765,6 +766,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |

### Pooling Models

@ -40,6 +40,34 @@ If other strategies don't solve the problem, it's likely that the vLLM instance
|
||||
- `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL.
|
||||
- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time.
|
||||
|
||||
## Breakpoints
|
||||
|
||||
Setting normal `pdb` breakpoints may not work in vLLM's codebase when the code is executed in a subprocess. You will see something like:
|
||||
|
||||
``` text
|
||||
File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 100, in trace_dispatch
|
||||
return self.dispatch_line(frame)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 125, in dispatch_line
|
||||
if self.quitting: raise BdbQuit
|
||||
^^^^^^^^^^^^^
|
||||
bdb.BdbQuit
|
||||
```
|
||||
|
||||
One solution is to use [forked-pdb](https://github.com/Lightning-AI/forked-pdb). Install it with `pip install fpdb` and set a breakpoint with something like:
|
||||
|
||||
``` python
|
||||
__import__('fpdb').ForkedPdb().set_trace()
|
||||
```
|
||||
|
||||
Another option is to disable multiprocessing entirely, with the `VLLM_ENABLE_V1_MULTIPROCESSING` environment variable.
|
||||
This keeps the scheduler in the same process, so you can use stock `pdb` breakpoints:
|
||||
|
||||
``` python
|
||||
import os
|
||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||
```
|
||||
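For example, a minimal sketch (the model name is just a placeholder):

``` python
import os

# Must be set before the engine is created.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # placeholder model
breakpoint()  # stock pdb works because the engine runs in this process
outputs = llm.generate(["Hello, my name is"])
```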
|
||||
## Incorrect network setup
|
||||
|
||||
If you have a complicated network config, the vLLM instance may not get the correct IP address. You can find a log line such as `DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl`; check that the IP address it reports is the correct one.
|
||||
|
@ -28,12 +28,15 @@ Learn more about Ray placement groups:
|
||||
https://docs.ray.io/en/latest/placement-groups.html
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import zmq
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
@ -86,20 +89,72 @@ class RayTrainingActor:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(0)
|
||||
self.zmq_context = zmq.Context()
|
||||
self.zmq_address_counter = 0
|
||||
self.zmq_handle = None
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
return self.device_uuid
|
||||
|
||||
def get_weight_ipc_handles(self):
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
def get_zmq_handles(self) -> dict[str, str]:
|
||||
suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
|
||||
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
|
||||
self.zmq_address_counter += 1
|
||||
return {self.device_uuid: self.zmq_handle}
|
||||
|
||||
data = {}
|
||||
for name, p in self.model.named_parameters():
|
||||
# A training actor might hold only a subset of the weights and may
|
||||
# need to gather weights from other actors. For demonstration
|
||||
# purposes, each training actor owns the full weight set.
|
||||
data[name] = reduce_tensor(p.detach())
|
||||
return {self.device_uuid: data}
|
||||
def update_weights(self):
|
||||
# align size to avoid misaligned address
|
||||
align_size = 256
|
||||
|
||||
def get_size(p: torch.Tensor) -> int:
|
||||
return (p.nbytes + align_size - 1) // align_size * align_size
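# e.g. a 1000-byte tensor is padded to 1024 bytes (4 * 256), so every
# tensor starts at a 256-byte-aligned offset inside the shared buffer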
|
||||
|
||||
named_parameters: dict[str, torch.nn.Parameter] = dict(
|
||||
self.model.named_parameters()
|
||||
)
|
||||
max_tensor_size = max(get_size(p) for p in named_parameters.values())
|
||||
# use max_tensor_size * 2 as buffer size
|
||||
buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
|
||||
s = self.zmq_context.socket(zmq.REQ)
|
||||
s.bind(self.zmq_handle)
|
||||
handle = reduce_tensor(buffer)
|
||||
|
||||
offset = 0
|
||||
buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
|
||||
named_tensors: list[dict] = []
|
||||
real_tensors: list[torch.Tensor] = []
|
||||
for name, p in named_parameters.items():
|
||||
size = get_size(p)
|
||||
if offset + size > buffer.numel():
|
||||
buckets.append((named_tensors, real_tensors))
|
||||
named_tensors, real_tensors = [], []
|
||||
offset = 0
|
||||
# assume tensors are contiguous
|
||||
named_tensors.append(
|
||||
{"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
|
||||
)
|
||||
real_tensors.append(p)
|
||||
offset += size
|
||||
if named_tensors:
|
||||
buckets.append((named_tensors, real_tensors))
|
||||
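# ZMQ handshake with the inference worker (REQ/REP, one recv per send):
# 1. send the CUDA IPC handle of the staging buffer
# 2. for each bucket: copy the weights into the buffer, then send the metadata list
# 3. send None to signal that all weights have been transferred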
s.send_pyobj(handle)
|
||||
s.recv()
|
||||
for named_tensors, real_tensors in buckets:
|
||||
offset = 0
|
||||
for p in real_tensors:
|
||||
buffer[offset : offset + p.nbytes].data.copy_(
|
||||
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
|
||||
)
|
||||
offset += get_size(p)
|
||||
torch.cuda.synchronize()
|
||||
s.send_pyobj(named_tensors)
|
||||
s.recv()
|
||||
s.send_pyobj(None)
|
||||
s.recv()
|
||||
s.close()
|
||||
del buffer
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Ray manages four GPUs.
|
||||
@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
|
||||
# the second inference engine.
|
||||
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
|
||||
|
||||
print("Gather all the IPC handles from the training actors.")
|
||||
ipc_handles = {}
|
||||
print("Gather all the ZMQ handles from the training actors.")
|
||||
zmq_handles = {}
|
||||
for actor in training_actors:
|
||||
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
|
||||
zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
|
||||
|
||||
print(f"ZMQ handles: {zmq_handles}")
|
||||
|
||||
print("Update the weights of the inference engines.")
|
||||
for llm in inference_engines:
|
||||
ray.get(
|
||||
llm.collective_rpc.remote(
|
||||
"update_weights_from_ipc_handles", args=(ipc_handles,)
|
||||
)
|
||||
)
|
||||
ray.get(
|
||||
[actor.update_weights.remote() for actor in training_actors]
|
||||
+ [
|
||||
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
|
||||
for llm in inference_engines
|
||||
]
|
||||
)
|
||||
|
||||
print("Check if the weights are updated.")
|
||||
for llm in inference_engines:
|
||||
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
|
||||
|
@ -1,6 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
from typing import Callable, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
|
||||
@ -66,6 +70,27 @@ class WorkerExtension:
|
||||
return weights_updated
|
||||
|
||||
|
||||
def rebuild_ipc(
|
||||
handle: tuple[Callable, tuple], device_id: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
func, args = handle
|
||||
list_args = list(args)
|
||||
if device_id is not None:
|
||||
# the key is to change device id to the current device id
|
||||
# in case two processes have different CUDA_VISIBLE_DEVICES
|
||||
list_args[6] = device_id
|
||||
buffer = func(*list_args)
|
||||
return buffer
|
||||
|
||||
|
||||
class FlattenedTensorMetadata(TypedDict):
|
||||
name: str
|
||||
shape: torch.Size
|
||||
dtype: torch.dtype
|
||||
# specify the start offset of this tensor in shared ipc_buffer tensor
|
||||
offset: int
|
||||
|
||||
|
||||
class ColocateWorkerExtension:
|
||||
"""
|
||||
The class for vLLM's worker to inherit from, in the colocate setting.
|
||||
@ -76,27 +101,62 @@ class ColocateWorkerExtension:
|
||||
should pass the full qualified name as `worker_extension_cls` argument.
|
||||
"""
|
||||
|
||||
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
|
||||
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
||||
|
||||
assert self.device is not None
|
||||
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
|
||||
self._zmq_ctx = zmq.Context()
|
||||
socket = self._zmq_ctx.socket(zmq.REP)
|
||||
socket.connect(zmq_handles[self.report_device_id()])
|
||||
buffer: Optional[torch.Tensor] = None
|
||||
while True:
|
||||
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
|
||||
socket.recv_pyobj()
|
||||
)
|
||||
if payload is None:
|
||||
# means the update is done
|
||||
process_weights_after_loading(
|
||||
self.model_runner.model, self.model_config, self.device
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
socket.send(b"")
|
||||
break
|
||||
if isinstance(payload, tuple):
|
||||
# an ipc handle that vLLM can use `func, args = handle`
|
||||
# and `func(*args)` to rebuild GPU tensor.
|
||||
buffer = rebuild_ipc(payload, self.device.index)
|
||||
assert buffer.dtype == torch.uint8
|
||||
socket.send(b"")
|
||||
continue
|
||||
assert isinstance(payload, list)
|
||||
assert buffer is not None
|
||||
weights = []
|
||||
for item in payload:
|
||||
shape = item["shape"]
|
||||
if isinstance(shape, (list, tuple)):
|
||||
shape = torch.Size(shape)
|
||||
assert isinstance(shape, torch.Size)
|
||||
dtype, offset = item["dtype"], item["offset"]
|
||||
size = dtype.itemsize * shape.numel()
|
||||
tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
|
||||
weights.append((item["name"], tensor))
|
||||
self.model_runner.model.load_weights(weights=weights)
|
||||
del weights
|
||||
torch.cuda.synchronize()
|
||||
socket.send(b"")
|
||||
|
||||
socket.close()
|
||||
del buffer
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(self.device.index)
|
||||
return self.device_uuid
|
||||
|
||||
def update_weights_from_ipc_handles(self, ipc_handles):
|
||||
handles = ipc_handles[self.device_uuid]
|
||||
device_id = self.device.index
|
||||
weights = []
|
||||
for name, handle in handles.items():
|
||||
func, args = handle
|
||||
list_args = list(args)
|
||||
# the key is to change device id to the current device id
|
||||
# in case two processes have different CUDA_VISIBLE_DEVICES
|
||||
list_args[6] = device_id
|
||||
tensor = func(*list_args)
|
||||
weights.append((name, tensor))
|
||||
self.model_runner.model.load_weights(weights=weights)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def check_weights_changed(self):
|
||||
"""
|
||||
Check if the weights are updated to 0.
|
||||
|
@ -6,6 +6,8 @@ import msgspec
|
||||
import zmq
|
||||
from msgspec.msgpack import Decoder
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
|
||||
#
|
||||
# Types copied from vllm.distributed.kv_events
|
||||
@ -22,8 +24,8 @@ class KVCacheEvent(
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
parent_block_hash: Optional[int]
|
||||
block_hashes: list[BlockHash]
|
||||
parent_block_hash: Optional[BlockHash]
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
lora_id: Optional[int]
|
||||
@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent):
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
block_hashes: list[BlockHash]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@
|
||||
<|system|>
|
||||
{{ system_message }}
|
||||
{%- if tools %}
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
In addition to plain text responses, you can choose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||
@ -19,7 +19,7 @@ If you decide to call functions:
|
||||
* prefix function calls with functools marker (no closing marker required)
|
||||
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||
* follow the provided JSON schema. Do not hallucinate arguments or values. Do not blindly copy values from the provided samples
|
||||
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||
* respect the argument type formatting. E.g., if the type is number and format is float, write value 7 as 7.0
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer == 0.11.3
|
||||
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
||||
outlines_core == 0.2.10
|
||||
outlines_core == 0.2.11
|
||||
# required for outlines backend disk cache
|
||||
diskcache == 5.6.3
|
||||
lark == 1.2.2
|
||||
|
@ -45,3 +45,34 @@ def test_bench_serve(server):
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve_chat(server):
|
||||
command = [
|
||||
"vllm",
|
||||
"bench",
|
||||
"serve",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--host",
|
||||
server.host,
|
||||
"--port",
|
||||
str(server.port),
|
||||
"--dataset-name",
|
||||
"random",
|
||||
"--random-input-len",
|
||||
"32",
|
||||
"--random-output-len",
|
||||
"4",
|
||||
"--num-prompts",
|
||||
"5",
|
||||
"--endpoint",
|
||||
"/v1/chat/completions",
|
||||
"--endpoint-type",
|
||||
"openai-chat",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
|
@ -4,9 +4,9 @@
|
||||
Test (piecewise) compilation with a simple model where multiple submodules
|
||||
are compiled and graph captured separately.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
@ -15,10 +15,9 @@ from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from .. import silly_attention # noqa: F401
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
@ -26,27 +25,6 @@ HIDDEN_SIZE = 1024
|
||||
RANDOM_SEED = 0
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class ParentModel(nn.Module):
|
||||
|
||||
|
@ -4,10 +4,10 @@
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
@ -15,35 +15,9 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global_counter = 0
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
global global_counter
|
||||
global_counter += 1
|
||||
print(f"{global_counter=}")
|
||||
out.copy_(q)
|
||||
out[0] += 1
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from ..silly_attention import get_global_counter, reset_global_counter
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@ -59,8 +33,7 @@ class SillyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Overall effect:
|
||||
x += 1
|
||||
x[0] += 2
|
||||
x = 3 * x + 19
|
||||
global_counter += 2
|
||||
"""
|
||||
x = x + 1
|
||||
@ -78,6 +51,7 @@ class SillyModel(nn.Module):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_simple_piecewise_compile(use_inductor):
|
||||
assert VLLM_USE_V1
|
||||
|
||||
@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
model(torch.randn(1).cuda())
|
||||
|
||||
input = torch.zeros(2).cuda()
|
||||
global global_counter
|
||||
global_counter = 0
|
||||
reset_global_counter()
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
assert get_global_counter() == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||
|
@ -14,38 +14,15 @@ from typing import Any, Optional
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from .. import silly_attention # noqa: F401
|
||||
|
||||
|
||||
@dataclass
|
||||
|
tests/compile/silly_attention.py (new file, 63 lines)
@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shared PyTorch custom silly attention for compilation tests.
|
||||
Centralizes custom operation definitions to avoid duplicate registrations.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# Shared library for all compilation test operations
|
||||
# Using "silly" namespace to match existing test expectations
|
||||
# importing this file automatically registers
|
||||
# torch ops for testing (like silly.attention)
|
||||
silly_lib = Library("silly", "FRAGMENT")
|
||||
|
||||
# Global counter that counts the number of times attention is invoked
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def get_global_counter():
|
||||
"""Get the current global counter value"""
|
||||
return _global_counter
|
||||
|
||||
|
||||
def reset_global_counter():
|
||||
"""Reset the global counter to 0"""
|
||||
global _global_counter
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Unified attention implementation that depends on
|
||||
all inputs and affects the output.
|
||||
Always increments a global counter that tests can use or ignore.
|
||||
"""
|
||||
global _global_counter
|
||||
|
||||
# Always increment the global counter
|
||||
_global_counter += 1
|
||||
|
||||
# Unified implementation that depends on all inputs
|
||||
out.copy_(q + k + v)
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
"""Fake implementation for testing"""
|
||||
return
|
||||
|
||||
|
||||
# Register the unified attention operation
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
@ -10,36 +9,14 @@ from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
@ -151,7 +128,7 @@ def test_ignore_torch_compile_decorator():
|
||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
@ -173,7 +150,7 @@ class B(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||
cache_config.kv_sharing_fast_prefill)
|
||||
|
@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import BaseIncrementalDetokenizer
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def include_stop_str_in_output(request):
|
||||
return request.param
|
||||
|
||||
|
||||
class _DummyDetokenizer(BaseIncrementalDetokenizer):
|
||||
|
||||
def __init__(self, request: EngineCoreRequest):
|
||||
super().__init__(request)
|
||||
|
||||
def decode_next(self, next_token_id: int) -> str:
|
||||
# Map token id to single ASCII character for deterministic testing.
|
||||
return chr(next_token_id)
|
||||
|
||||
|
||||
def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
||||
params = SamplingParams(
|
||||
stop=stop,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
min_tokens=min_tokens)
|
||||
# Keep other fields minimal for unit test purposes.
|
||||
req = EngineCoreRequest(
|
||||
request_id="test",
|
||||
prompt_token_ids=[],
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
pooling_params=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0.0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
)
|
||||
return req
|
||||
|
||||
|
||||
def test_stop_string_while_stop_token_terminates(
|
||||
include_stop_str_in_output: bool):
|
||||
"""
|
||||
This test verifies that the detokenizer correctly handles the case where
|
||||
the generated token sequence contains both:
|
||||
- a stop token
|
||||
- an <eos> token
|
||||
|
||||
The detokenizer should respect the stop string and truncate the output
|
||||
accordingly.
|
||||
|
||||
Imagine the following sequence:
|
||||
- "abcdeZ" is generated, where "Z" is the <eos> token.
|
||||
- "cd" is the stop string.
|
||||
|
||||
If include_stop_str_in_output=False, the detokenizer should truncate the
|
||||
output to "ab" because the stop string "cd" is excluded.
|
||||
If include_stop_str_in_output=True, the detokenizer should include the stop
|
||||
string "cd" in the output, resulting in "abcd".
|
||||
|
||||
|
||||
This verifies the behavioral change introduced in BaseIncrementalDetokenizer
|
||||
where stop-string evaluation occurs before the early-return on
|
||||
stop_terminated.
|
||||
"""
|
||||
|
||||
# Generate text "abcdeZ" and tokenize it.
|
||||
generated_text = "abcde"
|
||||
eos_token = "Z"
|
||||
stop_string = "cd"
|
||||
generated_text = generated_text + eos_token
|
||||
token_ids = [ord(c) for c in generated_text]
|
||||
|
||||
# Create a request with the stop string and initialize the detokenizer.
|
||||
req = _make_request(stop=[stop_string],
|
||||
include_stop_str_in_output=include_stop_str_in_output)
|
||||
detok = _DummyDetokenizer(req)
|
||||
|
||||
# Simulate that the last token ('Z') is a stop token (stop_terminated=True).
|
||||
result = detok.update(new_token_ids=token_ids, stop_terminated=True)
|
||||
|
||||
# The update should report the matched stop string
|
||||
assert result == stop_string
|
||||
|
||||
# Output text should reflect stop-string handling:
|
||||
# - include_stop_str_in_output=False => exclude "cd" => "ab"
|
||||
# - include_stop_str_in_output=True => include "cd" => "abcd"
|
||||
expected_text = "abcd" if include_stop_str_in_output else "ab"
|
||||
assert detok.output_text == expected_text
|
||||
|
||||
# The skipped final token should still be recorded in token_ids.
|
||||
assert detok.output_token_ids == token_ids
|
||||
|
||||
# get_next_output_text should return the full text when finished=True.
|
||||
# (Buffering only applies during streaming when finished=False.)
|
||||
assert detok.get_next_output_text(finished=True,
|
||||
delta=False) == expected_text
|
@ -25,7 +25,7 @@ class CustomUniExecutor(UniProcExecutor):
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None) -> list[Any]:
|
||||
# Drop marker to show that this was ran
|
||||
# Drop marker to show that this was run
|
||||
with open(".marker", "w"):
|
||||
...
|
||||
return super().collective_rpc(method, timeout, args, kwargs)
|
||||
|
@ -79,7 +79,7 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch):
|
||||
)
|
||||
|
||||
# Need to re-import huggingface_hub
|
||||
# and friends to setup offline mode
|
||||
# and friends to set up offline mode
|
||||
_re_import_modules()
|
||||
# Cached model files should be used in offline mode
|
||||
for model_config in MODEL_CONFIGS:
|
||||
@ -136,7 +136,7 @@ def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch):
|
||||
disable_connect,
|
||||
)
|
||||
# Need to re-import huggingface_hub
|
||||
# and friends to setup offline mode
|
||||
# and friends to set up offline mode
|
||||
_re_import_modules()
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
LLM(**dataclasses.asdict(engine_args))
|
||||
|
@ -10,7 +10,7 @@ import pytest
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.renderer import BaseRenderer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
@ -27,12 +27,16 @@ async def test_empty_prompt():
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
|
||||
with pytest.raises(openai.BadRequestError,
|
||||
match="decoder prompt cannot be empty"):
|
||||
with pytest.raises(
|
||||
openai.BadRequestError,
|
||||
match=
|
||||
"Either prompt or prompt_embeds must be provided and non-empty."
|
||||
):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt="",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": []})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -83,7 +87,7 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
|
||||
buffer.seek(0)
|
||||
encoded_tensor = pybase64.b64encode(buffer.getvalue())
|
||||
|
||||
loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
|
||||
loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor)
|
||||
assert len(loaded_prompt_embeds) == 1
|
||||
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
|
||||
assert loaded_tensor.device.type == "cpu"
|
||||
|
@ -1,13 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pybase64
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.renderer import CompletionRenderer
|
||||
from vllm.inputs.data import is_embeds_prompt
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -178,3 +182,132 @@ class TestRenderPrompt:
|
||||
with pytest.raises(ValueError, match="No tokenizer available"):
|
||||
await renderer_no_tokenizer.render_prompt(
|
||||
prompt_or_prompts="Hello world", max_length=100)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_input_with_needs_detokenization(
|
||||
self, renderer, mock_async_tokenizer):
|
||||
# When needs_detokenization=True for token inputs, renderer should
|
||||
# use the async tokenizer to decode and include the original text
|
||||
# in the returned prompt object.
|
||||
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
|
||||
renderer.async_tokenizer_pool[
|
||||
renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
tokens = [1, 2, 3, 4]
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts=tokens,
|
||||
needs_detokenization=True,
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_token_ids"] == tokens
|
||||
assert results[0]["prompt"] == "decoded text"
|
||||
mock_async_tokenizer.decode.assert_awaited_once()
|
||||
|
||||
|
||||
class TestRenderEmbedPrompt:
|
||||
|
||||
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
|
||||
"""Helper to create base64-encoded tensor bytes"""
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
return pybase64.b64encode(buffer.read())
|
||||
|
||||
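# Hypothetical client-side sketch (not part of this test): the same base64
# encoding could be sent to a running server via the OpenAI client's
# `extra_body`, mirroring the "prompt_embeds" field exercised elsewhere:
#
#   buf = io.BytesIO()
#   torch.save(torch.randn(10, 768), buf)
#   embeds_b64 = pybase64.b64encode(buf.getvalue()).decode()
#   client.completions.create(model="<served-model>", prompt="",
#                             extra_body={"prompt_embeds": embeds_b64})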
@pytest.mark.asyncio
|
||||
async def test_single_prompt_embed(self, renderer):
|
||||
# Create a test tensor
|
||||
test_tensor = torch.randn(10, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes, cache_salt="test_salt")
|
||||
|
||||
assert len(results) == 1
|
||||
assert is_embeds_prompt(results[0])
|
||||
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
|
||||
assert results[0]["cache_salt"] == "test_salt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_prompt_embeds(self, renderer):
|
||||
# Create multiple test tensors
|
||||
test_tensors = [
|
||||
torch.randn(8, 512, dtype=torch.float32),
|
||||
torch.randn(12, 512, dtype=torch.float32),
|
||||
]
|
||||
embed_bytes_list = [
|
||||
self._create_test_embed_bytes(t) for t in test_tensors
|
||||
]
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes_list)
|
||||
|
||||
assert len(results) == 2
|
||||
for i, result in enumerate(results):
|
||||
assert is_embeds_prompt(result)
|
||||
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_truncation(self, renderer):
|
||||
# Create tensor with more tokens than truncation limit
|
||||
test_tensor = torch.randn(20, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should keep last 10 tokens
|
||||
expected = test_tensor[-10:]
|
||||
assert torch.allclose(results[0]["prompt_embeds"], expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_different_dtypes(self, renderer):
|
||||
# Test different supported dtypes
|
||||
dtypes = [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
for dtype in dtypes:
|
||||
test_tensor = torch.randn(5, 256, dtype=dtype)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_embeds"].dtype == dtype
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
|
||||
# Test tensor with batch dimension gets squeezed
|
||||
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should be squeezed to 2D
|
||||
assert results[0]["prompt_embeds"].shape == (10, 768)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_prompts_and_embeds(self, renderer,
|
||||
mock_async_tokenizer):
|
||||
# Set up text tokenization
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult(
|
||||
[101, 102, 103])
|
||||
renderer.async_tokenizer_pool[
|
||||
renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
# Create embed
|
||||
test_tensor = torch.randn(5, 256, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
|
||||
|
||||
assert len(results) == 2
|
||||
# First should be embed prompt
|
||||
assert is_embeds_prompt(results[0])
|
||||
# Second should be tokens prompt
|
||||
assert "prompt_token_ids" in results[1]
|
||||
assert results[1]["prompt_token_ids"] == [101, 102, 103]
|
||||
|
@ -35,6 +35,7 @@ QUANT_DTYPES = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
@ -44,6 +45,7 @@ NUM_HEADS = [(64, 8), (40, 8)]
|
||||
HEAD_SIZE = [128]
|
||||
KV_LAYOUT = ["HND"] # currently only HND is supported
|
||||
BLOCK_SIZE = [16]
|
||||
WINDOW_LEFT = [-1, 127]
|
||||
SOFT_CAP = [None, 50.0]
|
||||
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
@ -57,6 +59,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
head_size: int,
|
||||
kv_layout: str,
|
||||
block_size: int,
|
||||
window_left: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
window_left=window_left,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
head_size: int,
|
||||
kv_layout: str,
|
||||
block_size: int,
|
||||
window_left: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
window_left=window_left,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
rtol, atol = 4e-1, 1e0
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 5e-2, 7e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||
rtol, atol = 4e-2, 6e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
|
@ -1247,7 +1247,7 @@ def baseline_scaled_mm(a: torch.Tensor,
|
||||
# then we would expand a to:
|
||||
# a = [[1, 1, 2, 2],
|
||||
# [3, 3, 4, 4]]
|
||||
# NOTE this function this function does not explicitly broadcast dimensions
|
||||
# NOTE this function does not explicitly broadcast dimensions
|
||||
# with an extent of 1, since this can be done implicitly by pytorch
|
||||
def group_broadcast(t, shape):
|
||||
for i, s in enumerate(shape):
|
||||
|
@ -93,7 +93,7 @@ AITER_MODEL_LIST = [
|
||||
"allenai/OLMoE-1B-7B-0924-Instruct",
|
||||
marks=[pytest.mark.cpu_model],
|
||||
),
|
||||
pytest.param("swiss-ai/Apertus-8B"), # apertus
|
||||
pytest.param("swiss-ai/Apertus-8B-2509"), # apertus
|
||||
])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
|
@ -301,7 +301,7 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
|
||||
finished_requests_ids is larger than the maximum mamba block capacity.
|
||||
|
||||
This could generally happen due to the fact that hybrid does support
|
||||
statelessness mechanism where it can cleanup new incoming requests in
|
||||
statelessness mechanism where it can clean up new incoming requests in
|
||||
a single step.
|
||||
"""
|
||||
try:
|
||||
@ -322,7 +322,7 @@ def test_state_cleanup(
|
||||
This test is for verifying that the Hybrid state is cleaned up between
|
||||
steps.
|
||||
|
||||
If its not cleaned, an error would be expected.
|
||||
If it's not cleaned, an error would be expected.
|
||||
"""
|
||||
try:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
|
@ -9,6 +9,7 @@ import mteb
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
|
||||
check_embeddings_close)
|
||||
@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
|
||||
vllm_extra_kwargs=None,
|
||||
hf_model_callback=None,
|
||||
atol=MTEB_EMBED_TOL):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
example_prompts = ["The chef prepared a delicious meal."]
|
||||
# Test embed_dims, isnan and whether to use normalize
|
||||
example_prompts = ["The chef prepared a delicious meal." * 1000]
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
|
||||
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
|
||||
# Confirm whether vllm is using the correct architecture
|
||||
if model_info.architecture:
|
||||
assert model_info.architecture in model_config.architectures
|
||||
|
||||
# Confirm whether vllm uses the correct default_pooling_type, which
|
||||
# relates to whether chunked prefill and prefix caching are enabled
|
||||
assert (model_config._model_info.default_pooling_type ==
|
||||
model_info.default_pooling_type)
|
||||
|
||||
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
||||
MTEB_EMBED_TASKS)
|
||||
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
|
||||
vllm_outputs = vllm_model.embed(example_prompts)
|
||||
|
||||
# Test embed_dims, isnan and whether to use normalize
|
||||
vllm_outputs = vllm_model.embed(example_prompts,
|
||||
truncate_prompt_tokens=-1)
|
||||
assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
|
||||
|
||||
# Accelerate mteb test by setting
|
||||
# SentenceTransformers mteb score to a constant
|
||||
if model_info.mteb_score is None:
|
||||
with hf_runner(model_info.name,
|
||||
is_sentence_transformer=True,
|
||||
dtype="float32") as hf_model:
|
||||
|
||||
# e.g. setting default parameters for the encode method of hf_runner
|
||||
if hf_model_callback is not None:
|
||||
hf_model_callback(hf_model)
|
||||
|
||||
@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
|
||||
hf_model_callback=None,
|
||||
vllm_mteb_encoder=VllmMtebEncoder,
|
||||
atol=MTEB_RERANK_TOL):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
|
||||
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
|
||||
# Confirm whether vllm is using the correct architecture
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture in model_config.architectures)
|
||||
|
||||
# Score API is only enabled for num_labels == 1
|
||||
assert model_config.hf_config.num_labels == 1
|
||||
|
||||
# Confirm whether vllm uses the correct default_pooling_type, which
|
||||
# relates to whether chunked prefill and prefix caching are enabled
|
||||
assert (model_config._model_info.default_pooling_type ==
|
||||
model_info.default_pooling_type)
|
||||
|
||||
@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
vllm_dtype = model_config.dtype
|
||||
|
||||
# Accelerate mteb test by setting
|
||||
# SentenceTransformers mteb score to a constant
|
||||
if model_info.mteb_score is None:
|
||||
st_main_score, st_dtype = mteb_test_rerank_models_hf(
|
||||
hf_runner, model_info.name, hf_model_callback)
|
||||
|
@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
|
||||
RERANK_MODELS = [
|
||||
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
||||
architecture="GemmaForSequenceClassification",
|
||||
mteb_score=0.33757,
|
||||
hf_overrides={
|
||||
"architectures":
|
||||
["GemmaForSequenceClassification"],
|
||||
|
@ -158,7 +158,7 @@ class _HfExamplesInfo:
|
||||
# yapf: disable
|
||||
_TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B",
|
||||
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-2509",
|
||||
min_transformers_version="4.56.0",
|
||||
trust_remote_code=True),
|
||||
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
|
||||
|
@ -28,7 +28,7 @@ ACCURACY_CONFIGS = [
|
||||
expected_value=0.76), # no bias
|
||||
# NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
|
||||
# so only one of these tests can run in a single call to pytest. As
|
||||
# a follow up, move this into the LM-EVAL section of the CI.
|
||||
# a follow-up, move this into the LM-EVAL section of the CI.
|
||||
# GSM8KAccuracyTestConfig(
|
||||
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
|
||||
# expected_value=0.66), # bias in QKV layers
|
||||
|
@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
|
||||
|
||||
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
||||
(None, bool, [1, 2, 3])])
|
||||
@pytest.mark.parametrize("output", [0, 1, 2])
|
||||
def test_sha256(input: tuple, output: int):
|
||||
hash = sha256(input)
|
||||
assert hash is not None
|
||||
assert isinstance(hash, int)
|
||||
assert hash != 0
|
||||
def test_sha256(input: tuple):
|
||||
digest = sha256(input)
|
||||
assert digest is not None
|
||||
assert isinstance(digest, bytes)
|
||||
assert digest != b""
|
||||
|
||||
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
|
||||
byteorder="big")
|
||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert digest == hashlib.sha256(input_bytes).digest()
|
||||
|
||||
# hashing again, returns the same value
|
||||
assert hash == sha256(input)
|
||||
assert digest == sha256(input)
|
||||
|
||||
# hashing different input, returns different value
|
||||
assert hash != sha256(input + (1, ))
|
||||
assert digest != sha256(input + (1, ))
|
||||
|
||||
|
||||
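# A hypothetical sketch of the `sha256` helper implied by the assertions
# above (pickle the object, then return the raw SHA-256 digest as bytes):
#
#   def sha256(obj) -> bytes:
#       data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
#       return hashlib.sha256(data).digest()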
@pytest.mark.parametrize(
|
||||
|
@ -6,20 +6,22 @@ from typing import Callable, Optional
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
|
||||
from vllm.utils import GiB_bytes, sha256, sha256_cbor
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||
get_request_block_hasher, hash_block_tokens, init_none_hash,
|
||||
is_kv_cache_type_uniform, unify_kv_cache_configs)
|
||||
is_kv_cache_type_uniform, make_block_hash_with_group_id,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor,
|
||||
SlidingWindowSpec)
|
||||
@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_none_hash(monkeypatch, hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
|
||||
@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
|
||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH != 0
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH != b""
|
||||
|
||||
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
|
||||
with monkeypatch.context() as m:
|
||||
@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
|
||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
|
||||
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
|
||||
|
||||
|
||||
def test_kv_cache_block():
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
|
||||
# Test KVCacheBlock initialization
|
||||
block = KVCacheBlock(block_id=0)
|
||||
@ -127,8 +128,7 @@ def test_kv_cache_block():
|
||||
assert block.ref_cnt == 0
|
||||
|
||||
# Test block hash setting and resetting
|
||||
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123,
|
||||
token_ids=(1, 2, 3))
|
||||
block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0)
|
||||
block.block_hash = block_hash
|
||||
assert block.block_hash == block_hash
|
||||
|
||||
@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
|
||||
assert next_mm_idx == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_hash_block_tokens(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
parent_block_hash = 123
|
||||
parent_block_hash = BlockHash(b"123")
|
||||
curr_block_token_ids = (1, 2, 3)
|
||||
extra_keys = ("key1", "key2")
|
||||
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
curr_block_token_ids, extra_keys)
|
||||
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
assert block_hash.hash_value == hash_fn(
|
||||
(parent_block_hash, curr_block_token_ids, extra_keys))
|
||||
assert block_hash.token_ids == curr_block_token_ids
|
||||
assert block_hash.extra_keys == extra_keys
|
||||
expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
|
||||
assert block_hash == expected
|
||||
|
||||
|
||||
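# Hypothetical sketch consistent with the assertion above: a block hash is
# hash_fn applied to (parent_block_hash, block_token_ids, extra_keys):
#
#   def hash_block_tokens(hash_fn, parent, token_ids, extra_keys=None) -> BlockHash:
#       return BlockHash(hash_fn((parent, token_ids, extra_keys)))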
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_request_block_hasher(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
kv_cache_utils.init_none_hash(hash_fn)
|
||||
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
|
||||
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
|
||||
# Check the first block
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
assert block_hashes[0].extra_keys == ("hash1", )
|
||||
|
||||
# Check the second block
|
||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
||||
assert block_hashes[1].extra_keys == ("hash2", )
|
||||
assert block_hashes[0] == hash_fn(
|
||||
(kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
|
||||
assert block_hashes[1] == hash_fn(
|
||||
(block_hashes[0], (3, 4, 5), ("hash2", )))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_hash_tokens_different_mm_input(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
assert block_hashes1[1] != block_hashes2[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
kv_cache_utils.init_none_hash(hash_fn)
|
||||
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
block_hashes = request.block_hashes
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
assert block_hashes[0].extra_keys is None
|
||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
||||
assert block_hashes[1].extra_keys is None
|
||||
assert block_hashes[0] == hash_fn(
|
||||
(kv_cache_utils.NONE_HASH, (0, 1, 2), None))
|
||||
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
|
||||
|
||||
|
||||
def test_metrics():
|
||||
|
@ -8,17 +8,19 @@ from typing import Callable, Optional
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
||||
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256, sha256_cbor_64bit
|
||||
from vllm.utils import sha256, sha256_cbor
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
get_block_hash, get_group_id,
|
||||
get_request_block_hasher,
|
||||
hash_block_tokens, init_none_hash)
|
||||
hash_block_tokens, init_none_hash,
|
||||
make_block_hash_with_group_id)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, SlidingWindowSpec)
|
||||
|
||||
@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_prefill(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(block_size, 11),
|
||||
@ -110,10 +114,6 @@ def test_prefill(hash_algo):
|
||||
enable_caching=True,
|
||||
)
|
||||
|
||||
# choose the hash function according to the parameter
|
||||
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
|
||||
sha256 if hash_algo == "sha256" else hash)
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
|
||||
@ -137,10 +137,12 @@ def test_prefill(hash_algo):
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
blk_hash = manager.block_pool.blocks[block_id].block_hash
|
||||
assert blk_hash is not None
|
||||
assert get_block_hash(blk_hash) == block_hash
|
||||
assert get_group_id(blk_hash) == 0
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
parent_block_hash = block_hash
|
||||
|
||||
# Check partial block metadata
|
||||
for block_id in (4, ):
|
||||
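The assertions above read the block hash and group id back out of a combined cache key. A standalone sketch of what keying cached blocks by (block hash, group id) buys for hybrid models follows; the byte-packing here is illustrative, and the real helpers may use a different layout.

```python
# Illustrative packing of a cached block key (not vLLM's actual layout).
BlockHash = bytes


def make_block_hash_with_group_id(block_hash: BlockHash, group_id: int) -> bytes:
    # Append the group id so hybrid models can cache the same block content
    # separately per KV cache group (e.g. full attention vs sliding window).
    return block_hash + group_id.to_bytes(4, "little")


def get_block_hash(key: bytes) -> BlockHash:
    return key[:-4]


def get_group_id(key: bytes) -> int:
    return int.from_bytes(key[-4:], "little")


key = make_block_hash_with_group_id(b"\x01\x02\x03", 2)
assert get_block_hash(key) == b"\x01\x02\x03"
assert get_group_id(key) == 2
```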
@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
|
||||
enable_caching=True,
|
||||
)
|
||||
|
||||
hash_fn = hash
|
||||
hash_fn = sha256
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||
@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
|
||||
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
for block_id in block_ids:
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
for group_id, block_id in enumerate(block_ids):
|
||||
blk_hash = manager.block_pool.blocks[block_id].block_hash
|
||||
assert blk_hash is not None
|
||||
assert get_block_hash(blk_hash) == block_hash
|
||||
assert get_group_id(blk_hash) == group_id
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
parent_block_hash = block_hash
|
||||
|
||||
# Check partial block metadata
|
||||
for block_id in (4, 8, 12):
|
||||
@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block)
|
||||
|
||||
def test_partial_request_hit(request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
|
||||
expect_hit_length: int):
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||
block_size, hash)
|
||||
block_size, sha256)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
hash_with_group_id)
|
||||
@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
|
||||
|
||||
# Evict the blocks outside sliding window, does not affect the hit length.
|
||||
test_partial_request_hit("2", [
|
||||
BlockHashWithGroupId(block_hashes[0], 1),
|
||||
BlockHashWithGroupId(block_hashes[0], 2)
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2)
|
||||
], 3)
|
||||
|
||||
# Evict the first block of full attention, makes total cache miss.
|
||||
test_partial_request_hit("3", [
|
||||
BlockHashWithGroupId(block_hashes[0], 0),
|
||||
], 0)
|
||||
test_partial_request_hit(
|
||||
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0)
|
||||
|
||||
# Evict the last block of all layers, reduces the hit length to 2.
|
||||
test_partial_request_hit("4", [
|
||||
BlockHashWithGroupId(block_hashes[2], 0),
|
||||
BlockHashWithGroupId(block_hashes[2], 1),
|
||||
BlockHashWithGroupId(block_hashes[2], 2),
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[2], 1),
|
||||
make_block_hash_with_group_id(block_hashes[2], 2),
|
||||
], 2)
|
||||
|
||||
# Evict the last block of full attention, reduces the hit length to 2.
|
||||
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
|
||||
2)
|
||||
test_partial_request_hit(
|
||||
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
|
||||
2)
|
||||
test_partial_request_hit(
|
||||
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
|
||||
2)
|
||||
test_partial_request_hit(
|
||||
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2)
|
||||
|
||||
# Evict different set of blocks for full attention and sliding window makes
|
||||
# total cache miss.
|
||||
@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
|
||||
# The cache hit length of sliding window is 2 * block_size.
|
||||
# Then it is cache miss as the two type of layers have different hit length.
|
||||
test_partial_request_hit("8", [
|
||||
BlockHashWithGroupId(block_hashes[2], 0),
|
||||
BlockHashWithGroupId(block_hashes[0], 1),
|
||||
BlockHashWithGroupId(block_hashes[0], 2),
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||
], 0)
|
||||
|
||||
|
||||
@ -372,8 +374,8 @@ def test_prefill_plp():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
# the default hash function is hash
|
||||
hash_fn = hash
|
||||
# the default hash function is sha256
|
||||
hash_fn = sha256
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@ -404,10 +406,12 @@ def test_prefill_plp():
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
blk_hash = (manager.block_pool.blocks[block_id].block_hash)
|
||||
assert blk_hash is not None
|
||||
assert get_block_hash(blk_hash) == block_hash
|
||||
assert get_group_id(blk_hash) == 0
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
parent_block_hash = block_hash
|
||||
|
||||
# Check partial block metadata
|
||||
for block_id in (4, ):
|
||||
@ -493,7 +497,7 @@ def test_decode():
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
|
||||
hash)
|
||||
sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -538,7 +542,7 @@ def test_evict():
|
||||
)
|
||||
|
||||
last_token_id = 5 * 16 + 7
|
||||
req0 = make_request("0", list(range(last_token_id)), block_size, hash)
|
||||
req0 = make_request("0", list(range(last_token_id)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -550,7 +554,7 @@ def test_evict():
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
last_token_id + 3 * 16)), block_size,
|
||||
hash)
|
||||
sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -572,7 +576,7 @@ def test_evict():
|
||||
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
|
||||
|
||||
# Touch the first 2 blocks.
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash)
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert computed_blocks.get_block_ids() == ([1, 2], )
|
||||
assert num_computed_tokens == 2 * 16
|
||||
@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
|
||||
|
||||
# Allocate 1 block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
req = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
|
||||
|
||||
# Allocate a new block that's not full, make sure hash info on the
|
||||
# block is cleared.
|
||||
req = make_request("1", list(range(num_tokens - 1)), block_size, hash)
|
||||
req = make_request("1", list(range(num_tokens - 1)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Allocate a block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Allocate another block.
|
||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
|
||||
block_size, hash)
|
||||
block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Now if we have a cache hit on the first block, we should evict the second
|
||||
# cached block rather than the first one.
|
||||
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash)
|
||||
req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
assert computed_blocks.blocks[0][0].block_id == 1
|
||||
@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
|
||||
)
|
||||
|
||||
req1 = make_request("1", list(range(10)), block_size,
|
||||
hash) # 2 blocks and some more
|
||||
sha256) # 2 blocks and some more
|
||||
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
|
||||
|
||||
# No caching.
|
||||
req2 = make_request("2", list(range(16)), block_size,
|
||||
hash) # shared prefix
|
||||
sha256) # shared prefix
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
||||
assert len(blocks.blocks[0]) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)), block_size, hash)
|
||||
req3 = make_request("3", list(range(4)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
|
||||
assert not blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_cache_blocks(hash_fn):
|
||||
"""
|
||||
This is a unit test that tests the correctness of the _cache_full_blocks
|
||||
@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
|
||||
# Block 1/5: [4, 5, 6, 7]
|
||||
# Block 2/6: [8, 9, 10, 11]
|
||||
# Block 3/7: [12, 13]
|
||||
req = make_request("0", list(range(14)), block_size, hash)
|
||||
req = make_request("0", list(range(14)), block_size, sha256)
|
||||
|
||||
# Cache the blocks for group 0.
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||
@ -845,6 +849,8 @@ def test_mm_prefix_caching():
|
||||
"""
|
||||
This tests that the multi-modal prefix caching is correct.
|
||||
"""
|
||||
kv_cache_utils.init_none_hash(sha256)
|
||||
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(block_size, 11),
|
||||
@ -874,23 +880,30 @@ def test_mm_prefix_caching():
|
||||
req0 = make_request("0",
|
||||
all_token_ids,
|
||||
block_size,
|
||||
hash,
|
||||
sha256,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
# Completed block should have hashes
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = req0.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("aaa", )
|
||||
assert block_hashes[1].extra_keys == ("aaa", "bbb")
|
||||
assert block_hashes[2].extra_keys == ("bbb", )
|
||||
assert block_hashes[0] == sha256(
|
||||
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
|
||||
("aaa", )))
|
||||
assert block_hashes[1] == sha256(
|
||||
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
|
||||
("aaa", "bbb")))
|
||||
assert block_hashes[2] == sha256(
|
||||
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
|
||||
("bbb", )))
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks is not None
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
@ -901,10 +914,10 @@ def test_mm_prefix_caching():
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
|
||||
# The just completed block should have hashes with extra keys.
|
||||
assert len(block_hashes) == 4
|
||||
assert block_hashes[3].extra_keys == ("ccc", )
|
||||
assert block_hashes[3] == sha256(
|
||||
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
|
||||
("ccc", )))
|
||||
|
||||
# Cache hit.
|
||||
unique_token_ids = [-1] * 7 + [200] * 5
|
||||
@ -916,7 +929,7 @@ def test_mm_prefix_caching():
|
||||
req1 = make_request("1",
|
||||
all_token_ids,
|
||||
block_size,
|
||||
hash,
|
||||
sha256,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
@ -929,6 +942,8 @@ def test_cache_key_salting():
|
||||
This tests that cache salts are applied during hashing and the cache
is separated as expected.
||||
"""
|
||||
kv_cache_utils.init_none_hash(sha256)
|
||||
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(block_size, 11),
|
||||
@ -939,21 +954,26 @@ def test_cache_key_salting():
|
||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||
token_ids = common_token_ids + [3] * 11
|
||||
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1")
|
||||
req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
# Completed block should have hashes
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = req0.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("salt1", )
|
||||
assert block_hashes[1].extra_keys is None
|
||||
assert block_hashes[2].extra_keys is None
|
||||
assert block_hashes[0] == sha256(
|
||||
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
|
||||
assert block_hashes[1] == sha256(
|
||||
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
|
||||
assert block_hashes[2] == sha256(
|
||||
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
|
||||
None))
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks is not None
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
@ -964,14 +984,13 @@ def test_cache_key_salting():
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
|
||||
# Now one more block that should not have extra keys.
|
||||
assert len(block_hashes) == 4
|
||||
assert block_hashes[3].extra_keys is None
|
||||
assert block_hashes[3] == sha256(
|
||||
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
|
||||
|
||||
# Test cache hit with a new request that has the same salt.
|
||||
token_ids = common_token_ids + [4] * 11
|
||||
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1")
|
||||
req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
# Should match only a prefix of 3 blocks.
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
@ -979,13 +998,19 @@ def test_cache_key_salting():
|
||||
|
||||
# Test cache miss with same content but different salt.
|
||||
token_ids = common_token_ids + [4] * 11
|
||||
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2")
|
||||
req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = req2.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("salt2", )
|
||||
assert block_hashes[0] == sha256(
|
||||
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
|
||||
assert block_hashes[1] == sha256(
|
||||
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
|
||||
assert block_hashes[2] == sha256(
|
||||
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
|
||||
None))
|
||||
|
||||
|
||||
def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# Complete 3 blocks (48 tokens)
|
||||
# | Common-0 | Common-1 | Common-2 | ... |
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
req0 = make_request("0", common_token_ids, block_size, hash)
|
||||
req0 = make_request("0", common_token_ids, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
req0.request_id]
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
||||
req1 = make_request("1", common_token_ids * 2, block_size, hash)
|
||||
req1 = make_request("1", common_token_ids * 2, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert computed_blocks.blocks[0] == block_part0
|
||||
assert num_computed_tokens == 3 * 16
|
||||
@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
||||
req2 = make_request("2", [7] * block_size * 2, block_size, hash)
|
||||
req2 = make_request("2", [7] * block_size * 2, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# but it cannot be allocated due to insufficient free blocks (2).
|
||||
# In this case, the ref_cnt of the computed blocks should not be changed.
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
||||
req3 = make_request("3", common_token_ids * 3, block_size, hash)
|
||||
req3 = make_request("3", common_token_ids * 3, block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert computed_blocks.blocks[0] == block_part1
|
||||
assert num_computed_tokens == 6 * 16
|
||||
@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
|
||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids, block_size, hash)
|
||||
req0 = make_request("0", all_token_ids, block_size, sha256)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req1 = make_request("1", all_token_ids, block_size, hash)
|
||||
req1 = make_request("1", all_token_ids, block_size, sha256)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||
assert len(req1.block_hashes) == 3
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
|
||||
assert manager.prefix_cache_stats is None
|
||||
|
||||
# Call all functions that check whether log_stats is disabled.
|
||||
req = make_request("0", list(range(16)), block_size, hash)
|
||||
req = make_request("0", list(range(16)), block_size, sha256)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
|
||||
|
||||
def test_maybe_evict_cached_block():
|
||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
||||
block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10,
|
||||
token_ids=(100, )),
|
||||
group_id=1000)
|
||||
block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
|
||||
token_ids=(200, )),
|
||||
group_id=2000)
|
||||
block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
|
||||
token_ids=(300, )),
|
||||
group_id=3000)
|
||||
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
||||
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
||||
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
||||
block_hashes = [
|
||||
block_hash0,
|
||||
block_hash1,
|
||||
@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
)
|
||||
|
||||
num_tokens = block_size * blocks_to_cache
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||
_ = manager.allocate_slots(req0, num_tokens)
|
||||
events = manager.take_events()
|
||||
|
||||
@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
# Should see block_to_cache number of removed block events and a new block
|
||||
# stored event
|
||||
manager.free(req0)
|
||||
req1 = make_request("1", list(range(num_tokens)), block_size, hash)
|
||||
req1 = make_request("1", list(range(num_tokens)), block_size, sha256)
|
||||
_ = manager.allocate_slots(req1, num_tokens)
|
||||
events = manager.take_events()
|
||||
|
||||
@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
|
||||
# Request with 3 full blocks (48 tokens)
|
||||
token_ids = [0] * (3 * block_size)
|
||||
req = make_request("divisible_request", token_ids, block_size, hash)
|
||||
req = make_request("divisible_request", token_ids, block_size, sha256)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
manager.free(req)
|
||||
|
||||
# New request with same tokens + Eagle enabled
|
||||
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash)
|
||||
req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
|
||||
# Should retain 1 block:
|
||||
@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
|
||||
)
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
||||
req = make_request("partial_block_test", token_ids, block_size, sha256)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
|
||||
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
||||
req = make_request("partial_block_test", token_ids, block_size, sha256)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
|
||||
assert manager.block_pool.get_cached_block(
|
||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
BlockHashWithGroupId(block_hash_first_block, 0))
|
||||
make_block_hash_with_group_id(block_hash_first_block, 0))
|
||||
|
||||
# New request
|
||||
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
|
||||
block_size, hash)
|
||||
block_size, sha256)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
|
||||
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
||||
# not considered. But after dropping the last matched block due to eagle,
|
||||
|
@ -6,8 +6,8 @@ import random
|
||||
import torch
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock)
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
make_block_hash_with_group_id)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
ChunkedLocalAttentionManager, SlidingWindowManager)
|
||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
|
||||
def run_one_case(block_is_cached, tail_token, expect_length):
|
||||
block_hash_list = [
|
||||
BlockHash(i, ()) for i in range(len(block_is_cached))
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
|
||||
block_hash, 0)] = {
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
|
||||
@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
|
||||
|
||||
def run_one_case(block_is_cached, expect_length):
|
||||
block_hash_list = [
|
||||
BlockHash(i, ()) for i in range(len(block_is_cached))
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
|
||||
block_hash, 0)] = {
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
@ -130,10 +131,10 @@ def create_requests(
|
||||
) -> list[Request]:
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
init_none_hash(sha256)
|
||||
_none_hash_initialized = True
|
||||
|
||||
block_hasher = get_request_block_hasher(block_size, hash)
|
||||
block_hasher = get_request_block_hasher(block_size, sha256)
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
|
@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
|
||||
assert vllm_config.cache_config.enable_prefix_caching
|
||||
|
||||
# default hash algorithm is "builtin"
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||
|
||||
# set hash algorithm to sha256_cbor
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == \
|
||||
"sha256_cbor"
|
||||
|
||||
# set hash algorithm to sha256
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||
|
||||
# set hash algorithm to builtin
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
|
||||
# an invalid hash algorithm raises an error
|
||||
parser.exit_on_error = False
|
||||
with pytest.raises(ArgumentError):
|
||||
|
@ -686,7 +686,7 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
|
||||
async def test_completion_with_empty_prompt_embeds(
|
||||
client: openai.AsyncOpenAI) -> None:
|
||||
"""Test completion with empty prompt embeds."""
|
||||
payload: dict[str, list] = {"prompt_embeds": []}
|
||||
payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []}
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
# base_url = http://localhost:8000/v1/completions
|
||||
response = requests.post(f"{client.base_url}completions",
|
||||
|
@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
||||
SharedStorageConnector)
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
@ -127,11 +128,11 @@ def create_request(request_id: int,
|
||||
use_all_1s_for_prompt_tokens: bool = False,
|
||||
num_remote_blocks: int = 3,
|
||||
block_size: int = 16,
|
||||
hash_fn: Callable = hash) -> Request:
|
||||
hash_fn: Callable = sha256) -> Request:
|
||||
"""Make dummy request for testing."""
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
init_none_hash(hash_fn)
|
||||
_none_hash_initialized = True
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
@ -29,7 +29,7 @@ run_mypy vllm/engine
|
||||
run_mypy vllm/executor
|
||||
run_mypy vllm/inputs
|
||||
run_mypy vllm/lora
|
||||
run_mypy vllm/model_executor
|
||||
run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor
|
||||
run_mypy vllm/plugins
|
||||
run_mypy vllm/worker
|
||||
run_mypy vllm/v1
|
||||
|
@@ -17,6 +17,47 @@ from tqdm.asyncio import tqdm
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


class StreamedResponseHandler:
    """Handles streaming HTTP responses by accumulating chunks until complete
    messages are available."""

    def __init__(self):
        self.buffer = ""

    def add_chunk(self, chunk_bytes: bytes) -> list[str]:
        """Add a chunk of bytes to the buffer and return any complete
        messages."""
        chunk_str = chunk_bytes.decode("utf-8")
        self.buffer += chunk_str

        messages = []

        # Split by double newlines (SSE message separator)
        while "\n\n" in self.buffer:
            message, self.buffer = self.buffer.split("\n\n", 1)
            message = message.strip()
            if message:
                messages.append(message)

        # if self.buffer is not empty, check if it is a complete message
        # by removing data: prefix and check if it is a valid JSON
        if self.buffer.startswith("data: "):
            message_content = self.buffer.removeprefix("data: ").strip()
            if message_content == "[DONE]":
                messages.append(self.buffer.strip())
                self.buffer = ""
            elif message_content:
                try:
                    json.loads(message_content)
                    messages.append(self.buffer.strip())
                    self.buffer = ""
                except json.JSONDecodeError:
                    # Incomplete JSON, wait for more chunks.
                    pass

        return messages

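For context, a hypothetical usage sketch of the StreamedResponseHandler added above. The chunk byte strings are made-up examples of one SSE event arriving split across packets; a message is only emitted once its JSON payload parses, and "[DONE]" is passed through as-is.

```python
# Hypothetical usage of the StreamedResponseHandler defined above; the chunk
# contents are illustrative, not captured server output.
handler = StreamedResponseHandler()

part1 = b'data: {"choices": [{"te'      # incomplete JSON -> stays buffered
part2 = b'xt": "Hello"}]}\n\n'           # completes the SSE event
done = b"data: [DONE]\n\n"

assert handler.add_chunk(part1) == []
assert handler.add_chunk(part2) == ['data: {"choices": [{"text": "Hello"}]}']
assert handler.add_chunk(done) == ["data: [DONE]"]
```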
@dataclass
|
||||
class RequestFuncInput:
|
||||
"""The input for the request function."""
|
||||
@ -102,46 +143,50 @@ async def async_request_openai_completions(
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
first_chunk_received = False
|
||||
async for chunk_bytes in response.content:
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if chunk_bytes.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.removeprefix("data: ")
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
@ -227,41 +272,44 @@ async def async_request_openai_chat_completions(
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
handler = StreamedResponseHandler()
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if chunk_bytes.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.removeprefix("data: ")
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
@ -347,36 +395,40 @@ async def async_request_openai_audio(
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
chunk = message.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
|
@ -258,8 +258,10 @@ class AttnFusionPass(VllmInductorPass):
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C,
|
||||
"scaled_fp4_quant"):
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning(
|
||||
|
@ -9,7 +9,6 @@ import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
@ -34,6 +33,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, PassConfig)
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
@ -745,7 +745,7 @@ class ModelConfig:
|
||||
|
||||
self.pooler_config = self._init_pooler_config()
|
||||
|
||||
self.dtype = _get_and_verify_dtype(
|
||||
self.dtype: torch.dtype = _get_and_verify_dtype(
|
||||
self.model,
|
||||
self.hf_config,
|
||||
self.dtype,
|
||||
@@ -1751,6 +1751,32 @@ class ModelConfig:
        # `llm as reranker` models defaults to not using pad_token.
        return getattr(self.hf_config, "use_pad_token", True)

    @property
    def head_dtype(self) -> torch.dtype:
        """
        "head" refers to the last Linear layer(s) of an LLM,
        such as the lm_head in a generation model,
        or the score or classifier in a classification model.

        The default head_dtype based on runner_type.\n
        - The pooling model defaults to using fp32 head,
        you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
        - The generate model defaults to not using fp32 head,
        you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
        """
        head_dtype = _get_head_dtype(config=self.hf_config,
                                     dtype=self.dtype,
                                     runner_type=self.runner_type)

        if head_dtype not in current_platform.supported_dtypes:
            logger.warning_once(
                "The current platform does not support [%s] head dtype, "
                "fallback to model dtype [%s].", head_dtype, self.dtype)
            return self.dtype

        logger.debug_once("head dtype: %s", head_dtype)
        return head_dtype

    def get_and_verify_max_len(self, max_model_len: int):
        # Consider max_model_len in tokenizer_config only when
        # pooling models use absolute position_embedding.
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
    return torch_dtype


def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
                    runner_type: str) -> torch.dtype:
    head_dtype: Optional[Union[str,
                               torch.dtype]] = getattr(config, "head_dtype",
                                                       None)

    if head_dtype == "model":
        return dtype
    elif isinstance(head_dtype, str):
        head_dtype = head_dtype.lower()
        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {head_dtype!r}")
        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
    elif isinstance(head_dtype, torch.dtype):
        return head_dtype
    elif head_dtype is None:
        if torch.float32 not in current_platform.supported_dtypes:
            return dtype
        if runner_type == "pooling":
            return torch.float32
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {head_dtype}")

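As a standalone sketch of the resolution order implemented by head_dtype and _get_head_dtype above: explicit override first, then platform support, then a runner-type default. The function and table names below are local to the example (not vLLM's API), and platform support is modelled by a simple boolean instead of current_platform.supported_dtypes.

```python
# Illustrative re-statement of the head dtype precedence (assumed names).
from typing import Optional, Union

import torch

# Illustrative name-to-dtype table; vLLM keeps its own mapping.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


def resolve_head_dtype(head_dtype: Optional[Union[str, torch.dtype]],
                       model_dtype: torch.dtype, runner_type: str,
                       fp32_supported: bool = True) -> torch.dtype:
    if head_dtype == "model":  # explicitly follow the model dtype
        return model_dtype
    if isinstance(head_dtype, str):  # e.g. "float32" via --hf-overrides
        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype.lower()]
    if isinstance(head_dtype, torch.dtype):
        return head_dtype
    # default: pooling runners get an fp32 head when the platform supports it
    if fp32_supported and runner_type == "pooling":
        return torch.float32
    return model_dtype


assert resolve_head_dtype(None, torch.float16, "pooling") is torch.float32
assert resolve_head_dtype(None, torch.float16, "generate") is torch.float16
assert resolve_head_dtype("model", torch.bfloat16, "pooling") is torch.bfloat16
```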
def _get_and_verify_max_len(
|
||||
hf_config: PretrainedConfig,
|
||||
tokenizer_config: Optional[dict],
|
||||
@ -3210,107 +3261,6 @@ class ObservabilityConfig:
|
||||
self.collect_detailed_traces[0].split(","))
|
||||
|
||||
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
KVRole = Literal[KVProducer, KVConsumer]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class KVTransferConfig:
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
kv_connector: Optional[str] = None
|
||||
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
"""
|
||||
|
||||
engine_id: Optional[str] = None
|
||||
"""The engine id for KV transfers."""
|
||||
|
||||
kv_buffer_device: Optional[str] = "cuda"
|
||||
"""The device used by kv connector to buffer the KV cache.
|
||||
Currently only support 'cuda'."""
|
||||
|
||||
kv_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
kv_role: Optional[KVRole] = None
|
||||
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
|
||||
|
||||
kv_rank: Optional[int] = None
|
||||
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
0 for prefill instance, 1 for decode instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
kv_parallel_size: int = 1
|
||||
"""The number of parallel instances for KV cache transfer. For
|
||||
P2pNcclConnector, this should be 2."""
|
||||
|
||||
kv_ip: str = "127.0.0.1"
|
||||
"""The KV connector ip, used to build distributed connection."""
|
||||
|
||||
kv_port: int = 14579
|
||||
"""The KV connector port, used to build distributed connection."""
|
||||
|
||||
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
kv_connector_module_path: Optional[str] = None
|
||||
"""The Python module path to dynamically load the KV connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.engine_id is None:
|
||||
self.engine_id = str(uuid.uuid4())
|
||||
|
||||
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
||||
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
|
||||
f"Supported roles are {get_args(KVRole)}")
|
||||
|
||||
if self.kv_connector is not None and self.kv_role is None:
|
||||
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
||||
f"is set, supported roles are {get_args(KVRole)}")
|
||||
|
||||
@property
|
||||
def is_kv_transfer_instance(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVRole)
|
||||
|
||||
@property
|
||||
def is_kv_producer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVProducer)
|
||||
|
||||
@property
|
||||
def is_kv_consumer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
|
||||
return self.kv_connector_extra_config.get(key, default)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class VllmConfig:
|
||||
@ -3891,6 +3841,7 @@ class VllmConfig:
|
||||
f"load_format={self.load_config.load_format}, "
|
||||
f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa
|
||||
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
|
||||
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
|
||||
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
|
||||
f"quantization={self.model_config.quantization}, "
|
||||
f"enforce_eager={self.model_config.enforce_eager}, "
|
||||
|
@ -24,7 +24,7 @@ logger = init_logger(__name__)
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||
MambaDType = Literal["auto", "float32"]
|
||||
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||
|
||||
|
||||
@config
|
||||
@ -63,17 +63,12 @@ class CacheConfig:
|
||||
"""Sliding window size for the KV cache. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
enable_prefix_caching: Optional[bool] = None
|
||||
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
|
||||
default for V1."""
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
|
||||
"""Whether to enable prefix caching. Enabled by default for V1."""
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
|
||||
"""Set the hash algorithm for prefix caching:\n
|
||||
- "builtin" is Python's built-in hash.\n
|
||||
- "sha256" is collision resistant but with certain overheads.
|
||||
This option uses Pickle for object serialization before hashing.\n
|
||||
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
|
||||
hash. It serializes objects using canonical CBOR and hashes them with
|
||||
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
|
||||
digest."""
|
||||
- "sha256" uses Pickle for object serialization before hashing.\n
|
||||
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
|
||||
serializes objects using canonical CBOR and hashes them with SHA-256."""
|
||||
cpu_offload_gb: float = 0
|
||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||
|
@ -546,7 +546,8 @@ class CompilationConfig:
|
||||
# full cudagraph outside the fx graph. This reduces some cpu
|
||||
# overhead when the runtime batch_size is not cudagraph captured.
|
||||
# see https://github.com/vllm-project/vllm/pull/20059 for details.
|
||||
self.splitting_ops = self._attention_ops
|
||||
# make a copy to avoid mutating the class-level list via reference.
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
elif len(self.splitting_ops) == 0:
|
||||
logger.warning_once("Using piecewise compilation with empty "
|
||||
"splitting_ops.")
|
||||
@ -561,6 +562,18 @@ class CompilationConfig:
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
self.splitting_ops = []
|
||||
|
||||
if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput":
|
||||
# exclude MoE dispatch/combine from capture by ensuring
|
||||
# piecewise splitting includes them, so communication remains
|
||||
# outside CUDA graphs while compute can still be graphed.
|
||||
moe_ops = [
|
||||
"vllm.moe_forward",
|
||||
"vllm.moe_forward_shared",
|
||||
]
|
||||
for op in moe_ops:
|
||||
if op not in self.splitting_ops:
|
||||
self.splitting_ops.append(op)
|
||||
|
||||
def splitting_ops_contain_attention(self) -> bool:
|
||||
return self.splitting_ops is not None and all(
|
||||
op in self.splitting_ops for op in self._attention_ops)
|
||||
|
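A small illustration of the aliasing bug that the `list(self._attention_ops)` copy above avoids; the class and the attention op names here are stand-ins, not vLLM's CompilationConfig, while "vllm.moe_forward" is taken from the diff.

```python
# Stand-in class demonstrating why assigning the class-level list directly is
# unsafe: instance-level appends would leak into every other instance.
class Cfg:
    _attention_ops = ["attention_op_a", "attention_op_b"]  # placeholder names


bad = Cfg()
bad.splitting_ops = Cfg._attention_ops             # aliases the class-level list
bad.splitting_ops.append("vllm.moe_forward")
assert "vllm.moe_forward" in Cfg._attention_ops    # leaked into the class

Cfg._attention_ops.remove("vllm.moe_forward")      # reset for the second case

good = Cfg()
good.splitting_ops = list(Cfg._attention_ops)      # independent copy
good.splitting_ops.append("vllm.moe_forward")
assert "vllm.moe_forward" not in Cfg._attention_ops
```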
111
vllm/config/kv_transfer.py
Normal file
@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, Optional, get_args

from pydantic.dataclasses import dataclass

from vllm.config.utils import config

KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer]


@config
@dataclass
class KVTransferConfig:
    """Configuration for distributed KV cache transfer."""

    kv_connector: Optional[str] = None
    """The KV connector for vLLM to transmit KV caches between vLLM instances.
    """

    engine_id: Optional[str] = None
    """The engine id for KV transfers."""

    kv_buffer_device: Optional[str] = "cuda"
    """The device used by kv connector to buffer the KV cache.
    Currently only support 'cuda'."""

    kv_buffer_size: float = 1e9
    """The buffer size for TorchDistributedConnector. Measured in number of
    bytes. Recommended value: 1e9 (about 1GB)."""

    kv_role: Optional[KVRole] = None
    """Whether this vLLM instance produces, consumes KV cache, or both. Choices
    are 'kv_producer', 'kv_consumer', and 'kv_both'."""

    kv_rank: Optional[int] = None
    """The rank of this vLLM instance in the KV cache transfer. Typical value:
    0 for prefill instance, 1 for decode instance.
    Currently only 1P1D is supported."""

    kv_parallel_size: int = 1
    """The number of parallel instances for KV cache transfer. For
    P2pNcclConnector, this should be 2."""

    kv_ip: str = "127.0.0.1"
    """The KV connector ip, used to build distributed connection."""

    kv_port: int = 14579
    """The KV connector port, used to build distributed connection."""

    kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
    """any extra config that the connector may need."""

    kv_connector_module_path: Optional[str] = None
    """The Python module path to dynamically load the KV connector from.
    Only supported in V1."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self) -> None:
        if self.engine_id is None:
            self.engine_id = str(uuid.uuid4())

        if self.kv_role is not None and self.kv_role not in get_args(KVRole):
            raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
                             f"Supported roles are {get_args(KVRole)}")

        if self.kv_connector is not None and self.kv_role is None:
            raise ValueError("Please specify kv_disagg_role when kv_connector "
                             f"is set, supported roles are {get_args(KVRole)}")

    @property
    def is_kv_transfer_instance(self) -> bool:
        return self.kv_connector is not None and \
            self.kv_role in get_args(KVRole)

    @property
    def is_kv_producer(self) -> bool:
        return self.kv_connector is not None and \
            self.kv_role in get_args(KVProducer)

    @property
    def is_kv_consumer(self) -> bool:
        return self.kv_connector is not None and \
            self.kv_role in get_args(KVConsumer)

    def get_from_extra_config(self, key, default) -> Any:
        return self.kv_connector_extra_config.get(key, default)
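A minimal usage sketch of the config above; the connector name and the extra-config key are illustrative only, and the field values follow the docstrings in the new file (for example kv_parallel_size=2 for P2pNcclConnector):

from vllm.config.kv_transfer import KVTransferConfig

cfg = KVTransferConfig(
    kv_connector="P2pNcclConnector",  # illustrative connector name
    kv_role="kv_producer",
    kv_rank=0,                        # 0 = prefill side, 1 = decode side
    kv_parallel_size=2,
)
assert cfg.is_kv_transfer_instance
assert cfg.is_kv_producer and not cfg.is_kv_consumer
# Unknown extra-config keys fall back to the supplied default:
assert cfg.get_from_extra_config("nccl_num_channels", "8") == "8"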
@ -16,6 +16,7 @@ import zmq

from vllm.config.kv_events import KVEventsConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ExternalBlockHash

logger = init_logger(__name__)

@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"

class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
block_hashes: list[ExternalBlockHash]
parent_block_hash: Optional[ExternalBlockHash]
token_ids: list[int]
block_size: int
lora_id: Optional[int]

@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):

class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
block_hashes: list[ExternalBlockHash]
medium: Optional[str]
@ -14,7 +14,8 @@ from vllm.logger import init_logger
# yapf: enable

if TYPE_CHECKING:
from vllm.config import KVTransferConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig

logger = init_logger(__name__)

@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm.config import KVTransferConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)

@ -15,7 +15,7 @@ import msgpack
import torch
import zmq

from vllm.config import KVTransferConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import (  # noqa: E501

@ -13,7 +13,7 @@ import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config import KVTransferConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils import join_host_port, make_zmq_path, split_host_port

@ -20,7 +20,7 @@ from typing import Callable, Optional

import torch

from vllm.config import KVTransferConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
@ -1117,7 +1117,7 @@ def initialize_model_parallel(
"decode context model parallel group is already initialized")
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change by DCP, it simply reuse the GPUs of TP group, and split one
# change by DCP, it simply reuses the GPUs of TP group, and split one
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(
    -1, decode_context_model_parallel_size).unbind(0)
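The reshape/unbind above is only a re-grouping of the existing TP ranks; a self-contained sketch with illustrative sizes:

import torch

tp_size, dcp_size = 8, 2
all_ranks = torch.arange(tp_size)
dcp_groups = all_ranks.reshape(-1, dcp_size).unbind(0)
# -> (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]))
# i.e. tp_size // dcp_size = 4 DCP groups that reuse the same GPUs as the TP group.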
@ -1592,20 +1592,12 @@ class EngineArgs:
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len)

# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching:
    # Disable prefix caching for multimodal models for VLLM_V0.
    if model_config.is_multimodal_model:
        logger.warning(
            "--enable-prefix-caching is not supported for multimodal "
            "models in V0 and has been disabled.")
        self.enable_prefix_caching = False

    # VLLM_V0 only supports builtin hash algo for prefix caching.
    if self.prefix_caching_hash_algo == "sha256":
        raise ValueError(
            "sha256 is not supported for prefix caching in V0 engine. "
            "Please use 'builtin'.")
# Disable prefix caching for multimodal models for VLLM_V0.
if self.enable_prefix_caching and model_config.is_multimodal_model:
    logger.warning(
        "--enable-prefix-caching is not supported for multimodal "
        "models in V0 and has been disabled.")
    self.enable_prefix_caching = False

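The simplified guard above can be distilled into a small pure function for reference (a hypothetical helper, not part of the diff):

def resolve_prefix_caching(enable: bool, is_multimodal: bool) -> bool:
    # V0 engine: prefix caching is force-disabled for multimodal models.
    if enable and is_multimodal:
        return False
    return enable

assert resolve_prefix_caching(True, is_multimodal=True) is False
assert resolve_prefix_caching(True, is_multimodal=False) is True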
# Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None:
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
@ -57,9 +59,14 @@ class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
|
||||
@ -89,9 +96,13 @@ class SimpleContext(ConversationContext):
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
|
||||
@ -103,6 +114,7 @@ class HarmonyContext(ConversationContext):
|
||||
self._messages = messages
|
||||
self.available_tools = available_tools
|
||||
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
||||
self.called_tools: set[str] = set()
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
@ -234,7 +246,8 @@ class HarmonyContext(ConversationContext):
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
return recipient is not None and (recipient.startswith("browser.")
|
||||
or recipient.startswith("python"))
|
||||
or recipient.startswith("python") or
|
||||
recipient.startswith("container."))
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
@ -248,6 +261,9 @@ class HarmonyContext(ConversationContext):
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self._tool_sessions["python"], last_msg)
|
||||
elif recipient.startswith("container."):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
@ -256,6 +272,7 @@ class HarmonyContext(ConversationContext):
|
||||
async def call_search_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
@ -265,12 +282,16 @@ class HarmonyContext(ConversationContext):
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author, content=[content], recipient=Role.ASSISTANT)
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
]
|
||||
|
||||
async def call_python_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
param = {
|
||||
@ -290,13 +311,63 @@ class HarmonyContext(ConversationContext):
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
self._tool_sessions[
|
||||
tool_name] = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name))
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id))
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def call_container_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
"""
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
]
|
||||
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
"""Can be used as coro to used in __aexit__"""
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info("Cleaning up tool session for %s",
|
||||
tool_session._client_info)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools))
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
|
@ -16,11 +16,13 @@ from openai.types.responses.response_function_web_search import (
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent)
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai_harmony import (Author, Conversation, DeveloperContent,
|
||||
HarmonyEncodingName, Message, ReasoningEffort,
|
||||
Role, StreamableParser, SystemContent, TextContent,
|
||||
ToolDescription, load_harmony_encoding)
|
||||
from openai_harmony import (Author, ChannelConfig, Conversation,
|
||||
DeveloperContent, HarmonyEncodingName, Message,
|
||||
ReasoningEffort, Role, StreamableParser,
|
||||
SystemContent, TextContent, ToolDescription,
|
||||
load_harmony_encoding)
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||
ResponseInputOutputItem)
|
||||
from vllm.utils import random_uuid
|
||||
@ -33,6 +35,20 @@ REASONING_EFFORT = {
|
||||
|
||||
_harmony_encoding = None
|
||||
|
||||
# Builtin tools that should be included in the system message when
|
||||
# they are available and requested by the user.
|
||||
# Tool args are provided by MCP tool descriptions. Output
|
||||
# of the tools are stringified.
|
||||
BUILTIN_TOOLS = {
|
||||
"web_search_preview",
|
||||
"code_interpreter",
|
||||
"container",
|
||||
}
|
||||
|
||||
|
||||
def has_custom_tools(tool_types: list[str]) -> bool:
|
||||
return not set(tool_types).issubset(BUILTIN_TOOLS)
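Quick behavioural check of the helper above; a request counts as having custom tools only if it asks for any tool type outside BUILTIN_TOOLS:

assert has_custom_tools(["web_search_preview", "container"]) is False
assert has_custom_tools(["web_search_preview", "function"]) is True
assert has_custom_tools([]) is False  # the empty set is a subset of BUILTIN_TOOLS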
def get_encoding():
|
||||
global _harmony_encoding
|
||||
@ -48,10 +64,19 @@ def get_system_message(
|
||||
start_date: Optional[str] = None,
|
||||
browser_description: Optional[str] = None,
|
||||
python_description: Optional[str] = None,
|
||||
container_description: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
with_custom_tools: bool = False,
|
||||
) -> Message:
|
||||
sys_msg_content = SystemContent.new()
|
||||
if model_identity is not None:
|
||||
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
|
||||
if (instructions is not None
|
||||
and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
||||
current_identity = sys_msg_content.model_identity
|
||||
new_identity = (f'{current_identity}\n{instructions}'
|
||||
if current_identity else instructions)
|
||||
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
|
||||
if reasoning_effort is not None:
|
||||
sys_msg_content = sys_msg_content.with_reasoning_effort(
|
||||
REASONING_EFFORT[reasoning_effort])
|
||||
@ -63,6 +88,14 @@ def get_system_message(
|
||||
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
||||
if python_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(python_description)
|
||||
if container_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(container_description)
|
||||
if not with_custom_tools:
|
||||
channel_config = sys_msg_content.channel_config
|
||||
invalid_channel = "commentary"
|
||||
new_config = ChannelConfig.require_channels(
|
||||
[c for c in channel_config.valid_channels if c != invalid_channel])
|
||||
sys_msg_content = sys_msg_content.with_channel_config(new_config)
|
||||
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
|
||||
return sys_msg
|
||||
|
||||
@ -86,14 +119,17 @@ def get_developer_message(
|
||||
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
|
||||
) -> Message:
|
||||
dev_msg_content = DeveloperContent.new()
|
||||
if instructions is not None:
|
||||
if (instructions is not None
|
||||
and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
|
||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||
if tools is not None:
|
||||
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||
for tool in tools:
|
||||
if tool.type in ("web_search_preview", "code_interpreter"):
|
||||
if tool.type in ("web_search_preview", "code_interpreter",
|
||||
"container"):
|
||||
# These are built-in tools that are added to the system message.
|
||||
pass
|
||||
|
||||
elif tool.type == "function":
|
||||
function_tools.append(tool)
|
||||
else:
|
||||
@ -136,6 +172,8 @@ def parse_response_input(
|
||||
TextContent(text=text_prefix + c["text"]) for c in content
|
||||
]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
if role == "assistant":
|
||||
msg = msg.with_channel("final")
|
||||
elif response_msg["type"] == "function_call_output":
|
||||
call_id = response_msg["call_id"]
|
||||
call_response: Optional[ResponseFunctionToolCall] = None
|
||||
|
@ -204,7 +204,7 @@ class LLM:
|
||||
|
||||
if "kv_transfer_config" in kwargs and isinstance(
|
||||
kwargs["kv_transfer_config"], dict):
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
raw_config_dict = kwargs["kv_transfer_config"]
|
||||
try:
|
||||
kwargs["kv_transfer_config"] = KVTransferConfig(
|
||||
|
@ -1717,6 +1717,8 @@ async def init_app_state(
|
||||
|
||||
if args.tool_server == "demo":
|
||||
tool_server: Optional[ToolServer] = DemoToolServer()
|
||||
assert isinstance(tool_server, DemoToolServer)
|
||||
await tool_server.init_and_validate()
|
||||
elif args.tool_server:
|
||||
tool_server = MCPToolServer()
|
||||
await tool_server.add_tool_server(args.tool_server)
|
||||
|
@ -134,14 +134,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
"""If specified, will run the OpenAI frontend server in the same process as
the model serving engine."""
enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses.
Caution: this hurts performance at high QPS."""
"""If specified, API server will add X-Request-Id header to responses."""
enable_auto_tool_choice: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
exclude_tools_when_tool_choice_none: bool = False
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use."""
exclude_tools_when_tool_choice_none: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
tool_call_parser: Optional[str] = None
"""Select the tool call parser depending on the model that you're using.
This is used to parse the model-generated tool call into OpenAI API format.
@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def validate_prompt_and_prompt_embeds(cls, data):
    if data.get("prompt") is None and data.get("prompt_embeds") is None:
    prompt = data.get("prompt")
    prompt_embeds = data.get("prompt_embeds")

    prompt_is_empty = (prompt is None
                       or (isinstance(prompt, str) and prompt == ""))
    embeds_is_empty = (prompt_embeds is None
                       or (isinstance(prompt_embeds, list)
                           and len(prompt_embeds) == 0))

    if prompt_is_empty and embeds_is_empty:
        raise ValueError(
            "At least one of `prompt` or `prompt_embeds` must be set.")
            "Either prompt or prompt_embeds must be provided and non-empty."
        )

    return data
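The emptiness rules introduced above, restated as a stand-alone sketch with hypothetical helper names, so the accepted and rejected payload shapes are easy to see:

from typing import Any, Optional

def _is_empty_prompt(prompt: Optional[Any]) -> bool:
    return prompt is None or (isinstance(prompt, str) and prompt == "")

def _is_empty_embeds(embeds: Optional[Any]) -> bool:
    return embeds is None or (isinstance(embeds, list) and len(embeds) == 0)

def validate(data: dict) -> dict:
    if _is_empty_prompt(data.get("prompt")) and _is_empty_embeds(
            data.get("prompt_embeds")):
        raise ValueError(
            "Either prompt or prompt_embeds must be provided and non-empty.")
    return data

validate({"prompt": "Hello"})                        # accepted
validate({"prompt": "", "prompt_embeds": [b"..."]})  # accepted: embeds present
# validate({"prompt": "", "prompt_embeds": []})      # would raise ValueError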
@model_validator(mode="before")
@ -26,12 +26,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
EmbedsPrompt as ServingEngineEmbedsPrompt)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
TextTokensPrompt,
|
||||
clamp_prompt_logprobs,
|
||||
is_text_tokens_prompt)
|
||||
clamp_prompt_logprobs)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
@ -132,12 +128,19 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request
|
||||
)
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
max_input_tokens_len = self.max_model_len - (request.max_tokens
|
||||
or 0)
|
||||
|
||||
request_prompts, engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
max_length=max_input_tokens_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
cache_salt=request.cache_salt,
|
||||
needs_detokenization=bool(request.echo
|
||||
and not request.return_token_ids),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
@ -198,7 +201,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
request_prompts[i],
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
@ -249,7 +252,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
@ -273,11 +276,9 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
request_prompt = request_prompts[i]
|
||||
if is_text_tokens_prompt(request_prompt):
|
||||
final_res.prompt = request_prompt["prompt"]
|
||||
else:
|
||||
final_res.prompt = None
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput],
|
||||
final_res_batch)
|
||||
@ -313,8 +314,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
request_prompts: list[Union[TextTokensPrompt,
|
||||
ServingEngineEmbedsPrompt]],
|
||||
engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
@ -350,14 +350,11 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
num_cached_tokens = res.num_cached_tokens
|
||||
first_iteration = False
|
||||
|
||||
if res.prompt is not None:
|
||||
prompt_text = res.prompt
|
||||
else:
|
||||
request_prompt = request_prompts[prompt_idx]
|
||||
if is_text_tokens_prompt(request_prompt):
|
||||
prompt_text = request_prompt["prompt"]
|
||||
else:
|
||||
prompt_text = None
|
||||
prompt_text = res.prompt
|
||||
if prompt_text is None:
|
||||
engine_prompt = engine_prompts[prompt_idx]
|
||||
prompt_text = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
@ -378,6 +375,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and not has_echoed[i]:
|
||||
assert prompt_token_ids is not None
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
# only return the prompt
|
||||
@ -525,6 +524,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo:
|
||||
if request.return_token_ids:
|
||||
prompt_text = ""
|
||||
assert prompt_text is not None
|
||||
if request.max_tokens == 0:
|
||||
token_ids = prompt_token_ids
|
||||
|
@ -28,7 +28,6 @@ from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
||||
TextTokensPrompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
@ -290,7 +289,7 @@ class EmbeddingMixin(OpenAIServing):
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt: EngineTokensPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Optional[Mapping[str, str]],
|
||||
prompt_index: int,
|
||||
@ -303,12 +302,6 @@ class EmbeddingMixin(OpenAIServing):
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Mypy has an existing bug related to inferring the variance
|
||||
# of TypedDicts with `builtins.enumerate`:
|
||||
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
|
||||
engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
|
||||
# Return the original generator without wrapping
|
||||
return self.engine_client.encode(
|
||||
engine_prompt,
|
||||
@ -375,12 +368,8 @@ class EmbeddingMixin(OpenAIServing):
|
||||
continue
|
||||
|
||||
# Normal processing for short prompts or non-token prompts
|
||||
# Cast engine_prompt to the expected type for mypy
|
||||
engine_prompt_typed = cast(
|
||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
generator = await self._create_single_prompt_generator(
|
||||
ctx, engine_prompt_typed, pooling_params, trace_headers, i)
|
||||
ctx, engine_prompt, pooling_params, trace_headers, i)
|
||||
generators.append(generator)
|
||||
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
@ -9,10 +7,8 @@ import traceback
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http import HTTPStatus
|
||||
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
|
||||
TypeVar, Union, cast, overload)
|
||||
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@ -64,10 +60,8 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
|
||||
# yapf: enable
|
||||
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -149,8 +143,7 @@ class RequestProcessingMixin(BaseModel):
|
||||
"""
|
||||
|
||||
request_prompts: Optional[Sequence[RequestPrompt]] = []
|
||||
engine_prompts: Optional[Union[list[EngineTokensPrompt],
|
||||
list[EngineEmbedsPrompt]]] = []
|
||||
engine_prompts: Optional[list[EngineTokensPrompt]] = []
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@ -368,13 +361,6 @@ class OpenAIServing:
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_id_item = f"{ctx.request_id}-{i}"
|
||||
|
||||
# Mypy has an existing bug related to inferring the variance of
|
||||
# TypedDicts with `builtins.enumerate`:
|
||||
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
|
||||
engine_prompt = cast(
|
||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
@ -737,170 +723,6 @@ class OpenAIServing:
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
async def _tokenize_prompt_input_or_inputs_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
|
||||
"""
|
||||
Tokenize/detokenize depending on the input format.
|
||||
|
||||
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
|
||||
, each input can be a string or array of tokens. Note that each request
|
||||
can pass one or more inputs.
|
||||
"""
|
||||
inputs_embeds = list[EmbedsPrompt]()
|
||||
inputs_text = list[TextTokensPrompt]()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
|
||||
if (truncate_prompt_tokens or 0) < 0:
|
||||
truncate_prompt_tokens = self.max_model_len
|
||||
|
||||
if (isinstance(request, CompletionRequest)
|
||||
and request.prompt_embeds is not None):
|
||||
inputs_embeds.extend(
|
||||
self._load_prompt_embeds(request.prompt_embeds,
|
||||
truncate_prompt_tokens))
|
||||
|
||||
# Empty prompts are okay as long as there are prompt embeddings
|
||||
if input_or_inputs is None or (inputs_embeds
|
||||
and input_or_inputs == ""):
|
||||
return [], inputs_embeds
|
||||
|
||||
# Although our type checking is based on mypy,
|
||||
# VSCode Pyright extension should still work properly
|
||||
# "is False" is required for Pyright to perform type narrowing
|
||||
# See: https://github.com/microsoft/pyright/issues/7672
|
||||
|
||||
# Parse and batch the input prompts
|
||||
batch_inputs = parse_and_batch_prompt(input_or_inputs)
|
||||
|
||||
# Process each input in the batch concurrently
|
||||
tasks = []
|
||||
for prompt_input in batch_inputs:
|
||||
if prompt_input["is_tokens"] is False:
|
||||
assert tokenizer is not None, (
|
||||
"Tokenizer is required for text prompts")
|
||||
task = self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
prompt_input["content"],
|
||||
tokenizer=tokenizer,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
task = self._normalize_prompt_tokens_to_input(
|
||||
request, prompt_input["content"], tokenizer=tokenizer)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all tokenization tasks to complete
|
||||
results = await asyncio.gather(*tasks)
|
||||
inputs_text.extend(results)
|
||||
|
||||
return inputs_text, inputs_embeds
|
||||
|
||||
@overload
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: Union[
|
||||
DetokenizeRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
RerankRequest,
|
||||
ClassificationRequest,
|
||||
ScoreRequest,
|
||||
TokenizeCompletionRequest,
|
||||
],
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
|
||||
add_special_tokens: bool = ...,
|
||||
) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]],
|
||||
add_special_tokens: bool = ...,
|
||||
) -> tuple[
|
||||
list[Union[TextTokensPrompt, EmbedsPrompt]],
|
||||
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
|
||||
]:
|
||||
...
|
||||
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: CompletionLikeRequest,
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
input_or_inputs: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> tuple[
|
||||
Union[list[TextTokensPrompt], list[Union[TextTokensPrompt,
|
||||
EmbedsPrompt]]],
|
||||
Union[
|
||||
list[EngineTokensPrompt],
|
||||
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
|
||||
],
|
||||
]:
|
||||
if (not isinstance(request, CompletionRequest)
|
||||
and input_or_inputs is None):
|
||||
raise ValueError(
|
||||
"Prompt embeds with non-completion requests is not"
|
||||
" currently supported.")
|
||||
|
||||
(
|
||||
request_prompts_text,
|
||||
request_prompts_embeds,
|
||||
) = await self._tokenize_prompt_input_or_inputs_async(
|
||||
request,
|
||||
tokenizer,
|
||||
input_or_inputs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
|
||||
engine_prompts_text = [
|
||||
EngineTokensPrompt(
|
||||
prompt_token_ids=request_prompt_text["prompt_token_ids"])
|
||||
for request_prompt_text in request_prompts_text
|
||||
]
|
||||
cache_salt = (request.cache_salt if
|
||||
(hasattr(request, "cache_salt")
|
||||
and request.cache_salt is not None) else None)
|
||||
if cache_salt:
|
||||
for prompt_text in engine_prompts_text:
|
||||
prompt_text["cache_salt"] = cache_salt
|
||||
|
||||
# This check is equivalent to simply checking if
|
||||
# `request_prompts_embeds` is empty, but it's difficult to propagate
|
||||
# overloads to the private helper functions to enable this check.
|
||||
# This overload is needed because only TextPrompts are allowed for
|
||||
# non-completion requests and if we don't add the overload here,
|
||||
# everywhere this function is used outside of serving_completion will
|
||||
# need logic asserting that only text prompts are in the request.
|
||||
if (not isinstance(request, CompletionRequest)
|
||||
and input_or_inputs is not None):
|
||||
return request_prompts_text, engine_prompts_text
|
||||
|
||||
engine_prompts_embeds = [
|
||||
EngineEmbedsPrompt(
|
||||
prompt_embeds=request_prompt_embeds["prompt_embeds"])
|
||||
for request_prompt_embeds in request_prompts_embeds
|
||||
]
|
||||
if cache_salt:
|
||||
for prompt_embed in engine_prompts_embeds:
|
||||
prompt_embed["cache_salt"] = cache_salt
|
||||
|
||||
request_prompts = request_prompts_embeds + request_prompts_text
|
||||
engine_prompts = engine_prompts_embeds + engine_prompts_text
|
||||
return request_prompts, engine_prompts
|
||||
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: Union[ChatLikeRequest, ResponsesRequest],
|
||||
@ -1073,41 +895,6 @@ class OpenAIServing:
|
||||
# OPTIMIZATION
|
||||
priority = orig_priority - 1
|
||||
|
||||
@staticmethod
|
||||
def _load_prompt_embeds(
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
) -> list[EmbedsPrompt]:
|
||||
|
||||
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
|
||||
tensor = torch.load(
|
||||
io.BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
if truncate_prompt_tokens is not None:
|
||||
tensor = tensor[-truncate_prompt_tokens:]
|
||||
return {"prompt_embeds": tensor}
|
||||
|
||||
if prompt_embeds:
|
||||
if isinstance(prompt_embeds, list):
|
||||
return [
|
||||
_load_and_validate_embed(embed) for embed in prompt_embeds
|
||||
]
|
||||
else:
|
||||
return [_load_and_validate_embed(prompt_embeds)]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
|
@ -44,8 +44,9 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext,
|
||||
SimpleContext, StreamingHarmonyContext)
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_developer_message, get_stop_tokens_for_assistant_actions,
|
||||
get_system_message, get_user_message, parse_output_message,
|
||||
parse_remaining_state, parse_response_input, render_for_completion)
|
||||
get_system_message, get_user_message, has_custom_tools,
|
||||
parse_output_message, parse_remaining_state, parse_response_input,
|
||||
render_for_completion)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -266,6 +267,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
builtin_tool_list.append("browser")
|
||||
if self.tool_server.has_tool("python"):
|
||||
builtin_tool_list.append("python")
|
||||
if self.tool_server.has_tool("container"):
|
||||
builtin_tool_list.append("container")
|
||||
|
||||
if self.tool_server is not None:
|
||||
available_tools = builtin_tool_list
|
||||
@ -448,7 +451,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
try:
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||
request.request_id)
|
||||
async for _ in result_generator:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
@ -710,13 +714,21 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# New conversation.
|
||||
reasoning_effort = (request.reasoning.effort
|
||||
if request.reasoning else None)
|
||||
# Temporary: OpenAI types doesn't have container tool
|
||||
# so we used MCP to cover that, up for change
|
||||
tool_types = [tool.type for tool in request.tools]
|
||||
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
|
||||
tool_types.append("container")
|
||||
enable_browser = ("web_search_preview" in tool_types
|
||||
and self.tool_server is not None
|
||||
and self.tool_server.has_tool("browser"))
|
||||
enable_code_interpreter = ("code_interpreter" in tool_types
|
||||
and self.tool_server is not None
|
||||
and self.tool_server.has_tool("python"))
|
||||
enable_container = ("container" in tool_types
|
||||
and self.tool_server is not None
|
||||
and self.tool_server.has_tool("container"))
|
||||
with_custom_tools = has_custom_tools(tool_types)
|
||||
sys_msg = get_system_message(
|
||||
reasoning_effort=reasoning_effort,
|
||||
browser_description=self.tool_server.get_tool_description(
|
||||
@ -725,11 +737,17 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
python_description=self.tool_server.get_tool_description(
|
||||
"python") if enable_code_interpreter
|
||||
and self.tool_server is not None else None,
|
||||
container_description=self.tool_server.get_tool_description(
|
||||
"container")
|
||||
if enable_container and self.tool_server is not None else None,
|
||||
instructions=request.instructions,
|
||||
with_custom_tools=with_custom_tools,
|
||||
)
|
||||
messages.append(sys_msg)
|
||||
dev_msg = get_developer_message(request.instructions,
|
||||
request.tools)
|
||||
messages.append(dev_msg)
|
||||
if with_custom_tools:
|
||||
dev_msg = get_developer_message(
|
||||
instructions=request.instructions, tools=request.tools)
|
||||
messages.append(dev_msg)
|
||||
else:
|
||||
# Continue the previous conversation.
|
||||
# FIXME(woosuk): Currently, request params like reasoning and
|
||||
@ -1613,7 +1631,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
processer = None
|
||||
if self.use_harmony:
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||
request.request_id)
|
||||
processer = self._process_harmony_streaming_events
|
||||
else:
|
||||
processer = self._process_simple_streaming_events
|
||||
|
@ -2,12 +2,16 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@ -49,37 +53,121 @@ class BaseRenderer(ABC):
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: Optional[bool] = True,
|
||||
cache_salt: Optional[str] = None,
|
||||
needs_detokenization: Optional[bool] = False,
|
||||
) -> list[EngineTokensPrompt]:
|
||||
"""
|
||||
Convert input prompts into tokenized format for engine processing.
|
||||
|
||||
This is the core method that transforms various input formats into
|
||||
standardized TokensPrompt objects. Implementations should handle
|
||||
tokenization, special token insertion, truncation, and validation
|
||||
according to model requirements.
|
||||
|
||||
Convert text or token inputs into engine-ready TokensPrompt objects.
|
||||
|
||||
This method accepts text or token inputs and produces a
|
||||
list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
|
||||
for the engine.
|
||||
|
||||
Args:
|
||||
prompt_or_prompts: Input data in various formats:
|
||||
- str: Single text prompt
|
||||
- list[str]: Batch of text prompts
|
||||
- list[int]: Pre-tokenized sequence
|
||||
- list[list[int]]: Batch of pre-tokenized sequences
|
||||
max_length: Maximum sequence length (endpoint-specific behavior)
|
||||
truncate_prompt_tokens: Truncate to last N tokens
|
||||
(None=no truncation, 0=empty)
|
||||
add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
|
||||
to text inputs
|
||||
cache_salt: Optional string to disambiguate cached prompts
|
||||
|
||||
prompt_or_prompts: One of:
|
||||
- ``str``: Single text prompt.
|
||||
- ``list[str]``: Batch of text prompts.
|
||||
- ``list[int]``: Single pre-tokenized sequence.
|
||||
- ``list[list[int]]``: Batch of pre-tokenized sequences.
|
||||
max_length: Maximum allowable total input token length. If provided,
|
||||
token inputs longer than this raise ``ValueError``.
|
||||
truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
|
||||
truncation. ``0`` yields an empty list (and skips embeds).
|
||||
``-1`` maps to ``model_config.max_model_len``.
|
||||
add_special_tokens: Whether to add model-specific special tokens
|
||||
during text tokenization.
|
||||
cache_salt: Optional string to disambiguate prefix cache entries.
|
||||
needs_detokenization: If True and ``prompt_or_prompts`` is token
|
||||
input, detokenize IDs back to text for inclusion in outputs.
|
||||
|
||||
Returns:
|
||||
list[EngineTokensPrompt]: Tokenized prompts ready for engine
|
||||
consumption
|
||||
|
||||
list[EngineTokensPrompt]: Engine-ready token prompts.
|
||||
|
||||
Raises:
|
||||
ValueError: If input format is invalid or length limits exceeded
|
||||
ValueError: If input formats are invalid or length limits exceeded.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def render_prompt_and_embeds(
|
||||
self,
|
||||
prompt_or_prompts: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]] = None,
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: Optional[bool] = True,
|
||||
cache_salt: Optional[str] = None,
|
||||
needs_detokenization: Optional[bool] = False,
|
||||
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
|
||||
"""
|
||||
Convert text/token and/or base64-encoded embeddings inputs into
|
||||
engine-ready prompt objects.
|
||||
|
||||
At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
|
||||
provided and non-empty. If both are omitted or empty (e.g., empty
|
||||
string and empty list), a ``ValueError`` is raised.
|
||||
|
||||
Args:
|
||||
prompt_or_prompts: Text or token inputs to include.
|
||||
prompt_embeds: Base64-encoded bytes (or list thereof) containing a
|
||||
torch-saved tensor to be used as prompt embeddings.
|
||||
max_length: Maximum allowable total input token length. If provided,
|
||||
inputs longer than this raise ``ValueError``.
|
||||
truncate_prompt_tokens: Number of tokens/rows to keep from the end
|
||||
of the sequence. ``-1`` maps to ``model_config.max_model_len``.
|
||||
add_special_tokens: Whether to add model-specific special tokens
|
||||
during text tokenization.
|
||||
cache_salt: Optional string to disambiguate prefix cache entries.
|
||||
needs_detokenization: If True and ``prompt_or_prompts`` is token
|
||||
input, detokenize IDs back to text for inclusion in outputs.
|
||||
|
||||
Returns:
|
||||
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
|
||||
Engine-ready prompt objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds``
|
||||
are omitted or empty (decoder prompt cannot be empty), or if
|
||||
length limits are exceeded.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def load_prompt_embeds(
|
||||
cls,
|
||||
prompt_embeds: Union[bytes, list[bytes]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
) -> list[EngineEmbedsPrompt]:
|
||||
"""Load and validate base64-encoded embeddings into prompt objects."""
|
||||
|
||||
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
|
||||
tensor = torch.load(
|
||||
io.BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
if truncate_prompt_tokens is not None:
|
||||
tensor = tensor[-truncate_prompt_tokens:]
|
||||
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
|
||||
if cache_salt is not None:
|
||||
embeds_prompt["cache_salt"] = cache_salt
|
||||
return embeds_prompt
|
||||
|
||||
if isinstance(prompt_embeds, list):
|
||||
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
|
||||
else:
|
||||
return [_load_and_validate_embed(prompt_embeds)]
|
|
||||
|
||||
@ -101,50 +189,110 @@ class CompletionRenderer(BaseRenderer):
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: Optional[bool] = True,
|
||||
cache_salt: Optional[str] = None,
|
||||
needs_detokenization: Optional[bool] = False,
|
||||
) -> list[EngineTokensPrompt]:
|
||||
"""Implementation of prompt rendering for completion-style requests.
|
||||
|
||||
Uses async tokenizer pooling for improved performance. See base class
|
||||
for detailed parameter documentation.
|
||||
"""
|
||||
if truncate_prompt_tokens is not None:
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
if truncate_prompt_tokens < 0:
|
||||
truncate_prompt_tokens = self.model_config.max_model_len
|
||||
if max_length is not None and truncate_prompt_tokens > max_length:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
|
||||
f"cannot be greater than max_length ({max_length}). "
|
||||
f"Please select a smaller truncation size.")
|
||||
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
|
||||
truncate_prompt_tokens, max_length)
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
|
||||
# Parse and batch the input prompts
|
||||
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
|
||||
|
||||
rendered_prompts: list[EngineTokensPrompt] = []
|
||||
tokenize_tasks = []
|
||||
tasks = []
|
||||
for prompt_input in batch_inputs:
|
||||
if prompt_input["is_tokens"] is True:
|
||||
# Token input
|
||||
token_ids = self._maybe_apply_truncation(
|
||||
prompt_input["content"], truncate_prompt_tokens)
|
||||
rendered_prompts.append(
|
||||
self._create_tokens_prompt(token_ids, max_length,
|
||||
cache_salt))
|
||||
detokenize_task = asyncio.create_task(
|
||||
# Note: detokenization is needed when echo is enabled,
|
||||
# where the input token IDs are decoded back to text.
|
||||
self._maybe_detokenize(prompt_input["content"], max_length,
|
||||
truncate_prompt_tokens, cache_salt,
|
||||
needs_detokenization))
|
||||
tasks.append(detokenize_task)
|
||||
else:
|
||||
# Text input
|
||||
tokenize_task = asyncio.create_task(
|
||||
self._tokenize(prompt_input["content"], max_length,
|
||||
truncate_prompt_tokens, add_special_tokens,
|
||||
cache_salt))
|
||||
tokenize_tasks.append(tokenize_task)
|
||||
tasks.append(tokenize_task)
|
||||
|
||||
# Wait for all text tokenization to finish
|
||||
if tokenize_tasks:
|
||||
tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
|
||||
rendered_prompts.extend(tokenized_text_prompts)
|
||||
if tasks:
|
||||
tokenized_text_prompts = await asyncio.gather(*tasks)
|
||||
return tokenized_text_prompts
|
||||
|
||||
return rendered_prompts
|
||||
return []
|
||||
|
||||
async def render_prompt_and_embeds(
|
||||
self,
|
||||
prompt_or_prompts: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]] = None,
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: Optional[bool] = True,
|
||||
cache_salt: Optional[str] = None,
|
||||
needs_detokenization: Optional[bool] = False,
|
||||
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
|
||||
"""
|
||||
Render text/token prompts and/or precomputed embedding prompts. At
|
||||
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
|
||||
"""
|
||||
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
|
||||
truncate_prompt_tokens, max_length)
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
|
||||
rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
|
||||
|
||||
if prompt_embeds is not None:
|
||||
rendered.extend(
|
||||
self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
|
||||
cache_salt))
|
||||
if prompt_or_prompts is None or prompt_or_prompts == "":
|
||||
return rendered
|
||||
|
||||
token_prompts = await self.render_prompt(
|
||||
prompt_or_prompts=prompt_or_prompts,
|
||||
max_length=max_length,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
cache_salt=cache_salt,
|
||||
needs_detokenization=needs_detokenization,
|
||||
)
|
||||
rendered.extend(token_prompts)
|
||||
|
||||
return rendered
|
||||
|
||||
def _validate_and_normalize_truncate_tokens(
|
||||
self,
|
||||
truncate_prompt_tokens: Optional[int],
|
||||
max_length: Optional[int],
|
||||
) -> Optional[int]:
|
||||
"""Validate and normalize truncate_prompt_tokens parameter."""
|
||||
if truncate_prompt_tokens is None:
|
||||
return None
|
||||
|
||||
if truncate_prompt_tokens == 0:
|
||||
return 0
|
||||
|
||||
if truncate_prompt_tokens < 0:
|
||||
truncate_prompt_tokens = self.model_config.max_model_len
|
||||
|
||||
if max_length is not None and truncate_prompt_tokens > max_length:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
|
||||
f"cannot be greater than max_length ({max_length}). "
|
||||
f"Please select a smaller truncation size.")
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
def _maybe_apply_truncation(
|
||||
self, token_ids: list[int],
|
||||
@ -186,7 +334,29 @@ class CompletionRenderer(BaseRenderer):
|
||||
max_length=truncate_prompt_tokens)
|
||||
|
||||
return self._create_tokens_prompt(encoded.input_ids, max_length,
|
||||
cache_salt)
|
||||
cache_salt, text)
|
||||
|
||||
async def _maybe_detokenize(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
max_length: Optional[int],
|
||||
truncate_prompt_tokens: Optional[int],
|
||||
cache_salt: Optional[str],
|
||||
needs_detokenization: Optional[bool] = False,
|
||||
) -> EngineTokensPrompt:
|
||||
"""Optionally detokenize token IDs and build a tokens prompt."""
|
||||
token_ids = self._maybe_apply_truncation(token_ids,
|
||||
truncate_prompt_tokens)
|
||||
|
||||
prompt = None
|
||||
if needs_detokenization is True:
|
||||
async_tokenizer = self._get_async_tokenizer()
|
||||
prompt = await async_tokenizer.decode(token_ids)
|
||||
|
||||
return self._create_tokens_prompt(token_ids=token_ids,
|
||||
max_length=max_length,
|
||||
cache_salt=cache_salt,
|
||||
prompt=prompt)
|
||||
|
||||
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
|
||||
"""Get or create async tokenizer using shared pool."""
|
||||
@ -210,6 +380,7 @@ class CompletionRenderer(BaseRenderer):
|
||||
token_ids: list[int],
|
||||
max_length: Optional[int] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
) -> EngineTokensPrompt:
|
||||
"""Create validated EngineTokensPrompt."""
|
||||
if max_length is not None and len(token_ids) > max_length:
|
||||
@ -221,4 +392,6 @@ class CompletionRenderer(BaseRenderer):
|
||||
tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
|
||||
if cache_salt is not None:
|
||||
tokens_prompt["cache_salt"] = cache_salt
|
||||
return tokens_prompt
|
||||
if prompt is not None:
|
||||
tokens_prompt["prompt"] = prompt
|
||||
return tokens_prompt
|
@ -4,6 +4,8 @@ import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai_harmony import Author, Message, Role, TextContent
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -99,6 +101,28 @@ class HarmonyPythonTool(Tool):
|
||||
return
|
||||
|
||||
self.python_tool = PythonTool()
|
||||
|
||||
async def validate(self):
|
||||
if not self.enabled:
|
||||
return
|
||||
try:
|
||||
message = Message(
|
||||
author=Author(role=Role.ASSISTANT),
|
||||
content=[TextContent(text="print('Hello, world!')")],
|
||||
channel="analysis",
|
||||
recipient="python",
|
||||
content_type="code",
|
||||
)
|
||||
msgs = []
|
||||
async for msg in self.python_tool.process(message):
|
||||
msgs.append(msg)
|
||||
assert msgs[0].content[0].text == "Hello, world!\n"
|
||||
except Exception as e:
|
||||
self.enabled = False
|
||||
logger.warning_once(
|
||||
"Code interpreter tool failed to initialize (%s), code "
|
||||
"interpreter is disabled", e)
|
||||
return
|
||||
logger.info_once("Code interpreter tool initialized")
|
||||
|
||||
async def get_result(self, context: "ConversationContext") -> Any:
|
||||
|
@ -86,7 +86,8 @@ class ToolServer(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
|
||||
def new_session(self, tool_name: str,
|
||||
session_id: str) -> AbstractAsyncContextManager[Any]:
|
||||
"""
|
||||
Create a session for the tool.
|
||||
"""
|
||||
@ -124,7 +125,8 @@ class MCPToolServer(ToolServer):
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema)
|
||||
for tool in list_tools_response.tools
|
||||
])
|
||||
],
|
||||
)
|
||||
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
|
||||
if tool_from_mcp.name not in self.urls:
|
||||
self.urls[tool_from_mcp.name] = url
|
||||
@ -142,14 +144,16 @@ class MCPToolServer(ToolServer):
|
||||
return self.harmony_tool_descriptions.get(tool_name)
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self, tool_name: str):
|
||||
async def new_session(self, tool_name: str, session_id: str):
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
url = self.urls.get(tool_name)
|
||||
headers = {"x-session-id": session_id}
|
||||
if not url:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
async with sse_client(url=url) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
async with sse_client(url=url,
|
||||
headers=headers) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
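The hunk above threads a per-request session_id through ToolServer.new_session, and MCPToolServer forwards it to the MCP server as an "x-session-id" header on the SSE connection. A minimal caller sketch under that assumption (run_tool and request_id are hypothetical names, not part of the diff):

# Hypothetical caller sketch: open a per-request tool session with the new signature.
async def run_tool(tool_server, tool_name: str, request_id: str):
    async with tool_server.new_session(tool_name, session_id=request_id) as session:
        # `session` is an mcp.ClientSession for MCPToolServer, or the Tool object
        # itself for DemoToolServer; use it only for the duration of this request.
        ...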
@ -158,10 +162,13 @@ class DemoToolServer(ToolServer):
|
||||
|
||||
def __init__(self):
|
||||
self.tools: dict[str, Tool] = {}
|
||||
|
||||
async def init_and_validate(self):
|
||||
browser_tool = HarmonyBrowserTool()
|
||||
python_tool = HarmonyPythonTool()
|
||||
await python_tool.validate()
|
||||
if browser_tool.enabled:
|
||||
self.tools["browser"] = browser_tool
|
||||
python_tool = HarmonyPythonTool()
|
||||
if python_tool.enabled:
|
||||
self.tools["python"] = python_tool
|
||||
logger.info("DemoToolServer initialized with tools: %s",
|
||||
@ -182,7 +189,7 @@ class DemoToolServer(ToolServer):
|
||||
raise ValueError(f"Unknown tool {tool_name}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self, tool_name: str):
|
||||
async def new_session(self, tool_name: str, session_id: str):
|
||||
if tool_name not in self.tools:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
yield self.tools[tool_name]
|
||||
|
17 vllm/envs.py
@ -168,7 +168,10 @@ if TYPE_CHECKING:
    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
    VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
    VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
    VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
    VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
    VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True


def get_default_cache_root():
@ -1201,9 +1204,23 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_TUNED_CONFIG_FOLDER":
    lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),

    # Allows vLLM to use the container tool
    "VLLM_GPT_OSS_USE_CONTAINER_TOOL":
    lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),

    # Allows harmony instructions to be injected on system messages
    "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
    lambda: bool(
        int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))),

    # Add optional custom scopes for profiling; disable to avoid overheads
    "VLLM_CUSTOM_SCOPES_FOR_PROFILING":
    lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),

    # Represent block hashes in KV cache events as 64-bit integers instead of
    # raw bytes. Defaults to True for backward compatibility.
    "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
    lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
}

# --8<-- [end:env-vars-definition]
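Given the parsing above, each new flag is read from the process environment and interpreted as "0"/"1" via bool(int(...)). A minimal sketch of toggling the GPT-OSS flags from Python (the values shown are illustrative, not defaults from this diff; vllm.envs resolves these attributes lazily at access time):

# Hypothetical usage sketch: enable the container tool and harmony system
# instructions, then read the resolved flags back through vllm.envs.
import os

os.environ["VLLM_GPT_OSS_USE_CONTAINER_TOOL"] = "1"
os.environ["VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS"] = "1"

import vllm.envs as envs

assert envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL is True
assert envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS is True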
@ -52,6 +52,9 @@ class TokensPrompt(TypedDict):
    prompt_token_ids: list[int]
    """A list of token IDs to pass to the model."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""

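With the new optional field, a tokens prompt can carry its original text alongside the token IDs, which is what _create_tokens_prompt above fills in when detokenization is requested. A small illustrative sketch (the token IDs and strings are made up):

# Hypothetical example: a TokensPrompt that also records the source text.
from vllm.inputs import TokensPrompt

tokens_prompt = TokensPrompt(
    prompt_token_ids=[9906, 11, 1917, 0],
    prompt="Hello, world!",   # new optional field added by this diff
    cache_salt="my-salt",     # optional prefix-cache salt
)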
8 vllm/model_executor/layers/fla/__init__.py Normal file
@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
17 vllm/model_executor/layers/fla/ops/__init__.py Normal file
@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import fused_recurrent_gated_delta_rule
from .layernorm_guard import RMSNormGated

__all__ = [
    "RMSNormGated",
    "chunk_gated_delta_rule",
    "fused_recurrent_gated_delta_rule",
]
225 vllm/model_executor/layers/fla/ops/chunk.py Normal file
@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=torch.float32)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
return g, o, A, final_state, w, h, v_new
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
@torch.amp.custom_fwd(device_type='cuda')
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
|
||||
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
ctx.scale = scale
|
||||
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
g (torch.Tensor):
|
||||
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
head_first (Optional[bool]):
|
||||
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
|
||||
Default: `False`.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> o, ht = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead.",
|
||||
stacklevel=2)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
||||
(q, k, v, beta, g))
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
return o, final_state
|
289 vllm/model_executor/layers/fla/ops/chunk_delta_h.py Normal file
@ -0,0 +1,289 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .op import exp, safe_exp
|
||||
from .utils import is_nvidia_hopper, use_cuda_graph
|
||||
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
|
||||
'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
|
||||
'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
|
||||
for num_warps in [2, 4] for num_stages in [2, 3, 4] for BV in [32, 64]
|
||||
],
|
||||
key=['H', 'K', 'V', 'BT', 'USE_G'],
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
v_new,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
SAVE_NEW_VALUE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# [BK, BV]
|
||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 64:
|
||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 128:
|
||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 192:
|
||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
|
||||
# calculate offset
|
||||
h += (boh * H + i_h) * K * V
|
||||
v += (bos * H + i_h) * V
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
w += (bos * H + i_h) * K
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new += (bos * H + i_h) * V
|
||||
stride_v = H * V
|
||||
stride_h = H * K * V
|
||||
stride_k = Hg * K
|
||||
stride_w = H * K
|
||||
if USE_INITIAL_STATE:
|
||||
h0 = h0 + i_nh * K * V
|
||||
if STORE_FINAL_STATE:
|
||||
ht = ht + i_nh * K * V
|
||||
|
||||
# load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV),
|
||||
(1, 0))
|
||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 64:
|
||||
p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 128:
|
||||
p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 192:
|
||||
p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h2,
|
||||
b_h2.to(p_h2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h3,
|
||||
b_h3.to(p_h3.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h4,
|
||||
b_h4.to(p_h4.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1),
|
||||
(i_t * BT, i_v * BV), (BT, BV),
|
||||
(1, 0)) if SAVE_NEW_VALUE else None
|
||||
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
|
||||
if K > 64:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
|
||||
if K > 128:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
|
||||
if K > 192:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
|
||||
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1),
|
||||
(i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_v_new,
|
||||
b_v_new.to(p_v_new.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
|
||||
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
|
||||
b_g_last = exp(b_g_last)
|
||||
b_h1 = b_h1 * b_g_last
|
||||
if K > 64:
|
||||
b_h2 = b_h2 * b_g_last
|
||||
if K > 128:
|
||||
b_h3 = b_h3 * b_g_last
|
||||
if K > 192:
|
||||
b_h4 = b_h4 * b_g_last
|
||||
b_v_new = b_v_new.to(k.dtype.element_ty)
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h1 += tl.dot(b_k, b_v_new)
|
||||
if K > 64:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h2 += tl.dot(b_k, b_v_new)
|
||||
if K > 128:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h3 += tl.dot(b_k, b_v_new)
|
||||
if K > 192:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h4 += tl.dot(b_k, b_v_new)
|
||||
|
||||
# epilogue
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV),
|
||||
(1, 0))
|
||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
tl.store(p_ht,
|
||||
b_h2.to(p_ht.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
tl.store(p_ht,
|
||||
b_h3.to(p_ht.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
tl.store(p_ht,
|
||||
b_h4.to(p_ht.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(
|
||||
chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = k.new_empty(
|
||||
N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), N * H)
|
||||
|
||||
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
|
||||
k=k,
|
||||
v=u,
|
||||
w=w,
|
||||
v_new=v_new,
|
||||
g=g,
|
||||
h=h,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT)
|
||||
return h, v_new, final_state
|
176 vllm/model_executor/layers/fla/ops/chunk_o.py Normal file
@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp, safe_exp
|
||||
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
||||
|
||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({
|
||||
'BK': BK,
|
||||
'BV': BV
|
||||
},
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages) for BK in BKV_LIST
|
||||
for BV in BKV_LIST for num_warps in NUM_WARPS
|
||||
for num_stages in [2, 3, 4]
|
||||
],
|
||||
key=['H', 'K', 'V', 'BT'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
i_tg = i_t
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
else:
|
||||
NT = tl.cdiv(T, BT)
|
||||
i_tg = i_b * NT + i_t
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
# offset calculation
|
||||
q += (bos * Hg + i_h // (H // Hg)) * K
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
v += (bos * H + i_h) * V
|
||||
o += (bos * H + i_h) * V
|
||||
h += (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
||||
(BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
|
||||
(BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
|
||||
(BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
if USE_G:
|
||||
g += bos * H + i_h
|
||||
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_o = b_o * exp(b_g)[:, None]
|
||||
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_A = o_i[:, None] >= o_i[None, :]
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# to fix mma -> mma layout conversion
|
||||
# already solved by triton v3.2 or higher
|
||||
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
if FLA_GDN_FIX_BT:
|
||||
BT = 64
|
||||
else:
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), NT, B * H)
|
||||
|
||||
chunk_fwd_kernel_o[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
)
|
||||
return o
|
138 vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py Normal file
@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import safe_exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
'USE_G': lambda args: args['g_cumsum'] is not None
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
|
||||
for BK in [32, 64, 128] for num_warps in [2, 4, 8]
|
||||
for num_stages in [2, 3, 4]
|
||||
],
|
||||
key=['H', 'K', 'BT', 'IS_VARLEN'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
beta,
|
||||
g_cumsum,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
o_t = tl.arange(0, BT)
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0, ))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
|
||||
(Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
|
||||
(1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = b_k * b_beta[:, None]
|
||||
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A = b_A * safe_exp(b_g_diff)
|
||||
|
||||
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
r"""
|
||||
Compute beta * K * K^T.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor):
|
||||
The key tensor of shape `[B, T, H, K]`.
|
||||
beta (torch.Tensor):
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g_cumsum (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`.
|
||||
Default: None
|
||||
cu_seqlens (torch.LongTensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
The chunk size. Default: 64.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float32`
|
||||
|
||||
Returns:
|
||||
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
|
||||
"""
|
||||
|
||||
B, T, Hg, K = k.shape
|
||||
|
||||
H = beta.shape[-1]
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g_cumsum,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
BT=BT,
|
||||
)
|
||||
return A
|
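A minimal usage sketch for the helper above, assuming a CUDA device with Triton available and the equal-length layout documented in the docstring (the shapes are illustrative, not taken from the diff):

# Hypothetical shapes: batch 1, 256 tokens, 4 heads, head dim 128, chunk size 64.
import torch
from vllm.model_executor.layers.fla.ops.chunk_scaled_dot_kkt import (
    chunk_scaled_dot_kkt_fwd)

B, T, H, K, BT = 1, 256, 4, 128, 64
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16).sigmoid()
# Each output row holds beta-scaled dot products with earlier keys in the same
# chunk (strictly lower-triangular within each chunk of size BT).
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, chunk_size=BT,
                             output_dtype=torch.float32)
assert A.shape == (B, T, H, BT)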
226 vllm/model_executor/layers/fla/ops/cumsum.py Normal file
@ -0,0 +1,226 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .utils import check_shared_mem, input_guard
|
||||
|
||||
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]
|
||||
],
|
||||
key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
# [BT]
|
||||
b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32)
|
||||
b_o = tl.cumsum(b_s, axis=0)
|
||||
if REVERSE:
|
||||
b_z = tl.sum(b_s, axis=0)
|
||||
b_o = -b_o + b_z[None] + b_s
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, ))
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST
|
||||
for num_warps in [2, 4, 8]
|
||||
],
|
||||
key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_vector_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
if REVERSE:
|
||||
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
|
||||
else:
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_local_cumsum_scalar(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
else:
|
||||
B, T, H = g.shape
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (NT, B * H)
|
||||
chunk_local_cumsum_scalar_kernel[grid](g_org,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum_vector(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T, S = g.shape
|
||||
else:
|
||||
B, T, H, S = g.shape
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
|
||||
|
||||
# keep cumulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_local_cumsum_vector_kernel[grid](g_org,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
S=S,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse)
|
||||
return g
|
||||
|
||||
|
||||
@input_guard
|
||||
def chunk_local_cumsum(g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs) -> torch.Tensor:
|
||||
if not head_first and g.shape[1] < g.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
assert g.shape[
|
||||
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
elif len(g.shape) == 4:
|
||||
return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape {g.shape}. "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise")
|
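For the equal-length, non-reverse, scalar case, the dispatch above reduces to a cumulative sum within each fixed-size chunk along the time axis. A pure-PyTorch reference sketch of that semantics (this helper is not part of the diff and assumes T is a multiple of chunk_size):

import torch

def chunk_local_cumsum_reference(g: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Per-chunk cumulative sum of a [B, T, H] gate tensor (head_first=False)."""
    B, T, H = g.shape
    assert T % chunk_size == 0
    return (g.float()
            .view(B, T // chunk_size, chunk_size, H)
            .cumsum(dim=2)
            .view(B, T, H))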
366 vllm/model_executor/layers/fla/ops/fused_recurrent.py Normal file
@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .op import exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_INITIAL_STATE':
|
||||
lambda args: args['h0'] is not None,
|
||||
'IS_VARLEN':
|
||||
lambda args: args['cu_seqlens'] is not None,
|
||||
"IS_CONTINUOUS_BATCHING":
|
||||
lambda args: args['ssm_state_indices'] is not None,
|
||||
"IS_SPEC_DECODING":
|
||||
lambda args: args['num_accepted_tokens'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['N', 'T'])
|
||||
def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
scale,
|
||||
N: tl.constexpr, # num of sequences
|
||||
T: tl.constexpr, # num of tokens
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
stride_indices_tok: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
|
||||
IS_BETA_HEADWISE: tl.
|
||||
constexpr, # whether beta is headwise vector or scalar,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
):
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
|
||||
if T == 0:
|
||||
# no tokens to process for this sequence
|
||||
return
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_q = q + (bos * H + i_h) * K + o_k
|
||||
p_k = k + (bos * H + i_h) * K + o_k
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
if IS_BETA_HEADWISE:
|
||||
p_beta = beta + (bos * HV + i_hv) * V + o_v
|
||||
else:
|
||||
p_beta = beta + bos * HV + i_hv
|
||||
p_g = g + bos * HV + i_hv
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
if IS_SPEC_DECODING:
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_init_state_token
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * K * V
|
||||
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for i_t in range(0, T):
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
b_h *= exp(b_g)
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
||||
if IS_BETA_HEADWISE:
|
||||
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
||||
else:
|
||||
b_beta = tl.load(p_beta).to(tl.float32)
|
||||
b_v *= b_beta
|
||||
# [BK, BV]
|
||||
b_h += b_k[:, None] * b_v[None, :]
|
||||
# [BV]
|
||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_final_state_token
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
p_q += H * K
|
||||
p_k += H * K
|
||||
p_o += HV * V
|
||||
p_v += HV * V
|
||||
p_g += HV
|
||||
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
if inplace_final_state:
|
||||
final_state = initial_state
|
||||
else:
|
||||
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
|
||||
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = final_state.stride(0)
|
||||
|
||||
if ssm_state_indices is None:
|
||||
stride_indices_seq, stride_indices_tok = 1, 1
|
||||
elif ssm_state_indices.ndim == 1:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
|
||||
else:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
|
||||
|
||||
grid = (NK, NV, N * HV)
|
||||
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
o=o,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
scale=scale,
|
||||
N=N,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
stride_indices_tok=stride_indices_tok,
|
||||
IS_BETA_HEADWISE=beta.ndim == v.ndim,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
INPLACE_FINAL_STATE=inplace_final_state,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o, final_state
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        o, final_state = fused_recurrent_gated_delta_rule_fwd(
            q=q.contiguous(),
            k=k.contiguous(),
            v=v.contiguous(),
            g=g.contiguous(),
            beta=beta.contiguous(),
            scale=scale,
            initial_state=initial_state,
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

        return o, final_state


def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = None,
    initial_state: torch.Tensor = None,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (torch.Tensor):
            betas of shape `[B, T, HV]`.
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state (bool):
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, K, V]`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            cu_seqlens=cu_seqlens
        )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            f"Please flatten variable-length inputs before processing.")
    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
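Note (editor's sketch, not part of the diff): the docstring does not show how `ssm_state_indices` is meant to be used. If the initial/final states live in a pooled cache, the index tensor presumably selects one state slot per sequence in `cu_seqlens`; the slot numbers below are purely illustrative.

    # assumed setup: a pool of 16 persistent state slots, 4 flattened sequences
    state_pool = torch.zeros(16, HV, K, V, device='cuda')
    slot_ids = torch.tensor([3, 7, 0, 12], device='cuda', dtype=torch.int32)
    o_var, ht_var = fused_recurrent_gated_delta_rule(
        q, k, v, g, beta,
        initial_state=state_pool,
        inplace_final_state=True,   # final states written back into the pool
        cu_seqlens=cu_seqlens,
        ssm_state_indices=slot_ids,
    )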
39  vllm/model_executor/layers/fla/ops/index.py  Normal file
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch

from vllm.triton_utils import triton

from .utils import tensor_cache


@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    indices = torch.cat([
        torch.arange(n)
        for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    ])
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
                       1).to(cu_seqlens)


@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    return torch.cat([
        cu_seqlens.new_tensor([0]),
        triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    ]).cumsum(-1)
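For reference, a small worked example of what these helpers return, computed by hand from the definitions above (the concrete lengths are illustrative only):

    cu_seqlens = torch.tensor([0, 5, 13])   # two sequences of length 5 and 8
    prepare_lens(cu_seqlens)                # tensor([5, 8])
    prepare_chunk_indices(cu_seqlens, 4)    # tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
                                            # rows are (sequence id, chunk id within sequence)
    prepare_chunk_offsets(cu_seqlens, 4)    # tensor([0, 2, 4]), cumulative chunk counts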
143  vllm/model_executor/layers/fla/ops/l2norm.py  Normal file
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import os
from typing import Optional

import torch

from vllm.triton_utils import tl, triton

BT_LIST = [8, 16, 32, 64, 128]

USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))


@triton.autotune(configs=[
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16, 32]
],
                 key=['D'])
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute the sum of squares along the feature dimension
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # tl.store(Rstd + i_t, rstd)
    # Normalize
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


@triton.autotune(configs=[
    triton.Config({'BT': BT}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
                 key=['D'])
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


@triton.jit
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * MBLOCK
    row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
    xmask = row_idx < M
    rindex = tl.arange(0, N)[None, :]
    xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32)
    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
    square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
    rsqrt = tl.rsqrt(square_sum + eps)
    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)


def l2norm_fwd(x: torch.Tensor,
               eps: float = 1e-6,
               output_dtype: Optional[torch.dtype] = None):
    x_shape_og = x.shape
    x = x.view(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if not USE_DEFAULT_FLA_NORM:
        MBLOCK = 32
        # M, N = x.shape
        l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
            x,
            y,
            eps,
            T,
            D,
            MBLOCK,
        )
    else:
        if D <= 512:
            NB = triton.cdiv(T, 2048)

            def grid(meta):
                return (triton.cdiv(T, meta['BT']), )

            l2norm_fwd_kernel[grid](
                x,
                y,
                eps,
                NB=NB,
                T=T,
                D=D,
                BD=BD,
            )
        else:
            l2norm_fwd_kernel1[(T, )](
                x,
                y,
                eps=eps,
                D=D,
                BD=BD,
            )

    return y.view(x_shape_og)
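All three kernels implement the same row-wise normalization over the last dimension; a minimal eager-mode reference for checking them (an editor's sketch, not part of the diff) is:

    def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # x / sqrt(sum(x**2, dim=-1) + eps), accumulated in fp32 like the kernels
        xf = x.float()
        return (xf * torch.rsqrt(xf.square().sum(-1, keepdim=True) + eps)).to(x.dtype)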
337  vllm/model_executor/layers/fla/ops/layernorm_guard.py  Normal file
@@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Tri Dao
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2024, Tri Dao.

# ruff: noqa: E501
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from vllm.triton_utils import tl, triton

from .utils import input_guard


def rms_norm_ref(x,
                 weight,
                 bias,
                 z=None,
                 eps=1e-6,
                 group_size=None,
                 norm_before_gate=True,
                 upcast=True):
    dtype = x.dtype
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    else:
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype)


@triton.heuristics({
    "HAS_BIAS": lambda args: args["B"] is not None,
    "HAS_Z": lambda args: args["Z"] is not None,
})
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)


def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    z: torch.Tensor = None,
    out: torch.Tensor = None,
    group_size: int = None,
    norm_before_gate: bool = True,
    is_rms_norm: bool = False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N, )
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N, )
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = torch.empty((ngroups * M, ), dtype=torch.float32,
                       device=x.device) if not is_rms_norm else None
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    layer_norm_fwd_kernel[grid](x,
                                out,
                                weight,
                                bias,
                                z,
                                mean,
                                rstd,
                                x.stride(0),
                                out.stride(0),
                                z.stride(0) if z is not None else 0,
                                M,
                                group_size,
                                eps,
                                BLOCK_N=BLOCK_N,
                                NORM_BEFORE_GATE=norm_before_gate,
                                IS_RMS_NORM=is_rms_norm,
                                num_warps=num_warps)
    return out, mean, rstd


class LayerNormFn(torch.autograd.Function):

    @input_guard
    @staticmethod
    def forward(ctx,
                x,
                weight,
                bias,
                z=None,
                eps=1e-6,
                group_size=None,
                norm_before_gate=True,
                is_rms_norm=False):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.group_size = group_size
        ctx.norm_before_gate = norm_before_gate
        ctx.is_rms_norm = is_rms_norm
        return y.reshape(x_shape_og)


def layernorm_fn(x,
                 weight,
                 bias,
                 z=None,
                 eps=1e-6,
                 group_size=None,
                 norm_before_gate=True,
                 is_rms_norm=False):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, is_rms_norm)


def rmsnorm_fn(x,
               weight,
               bias,
               z=None,
               eps=1e-6,
               group_size=None,
               norm_before_gate=True):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, True)


class LayerNormGated(nn.Module):

    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: Optional[int] = None,
        norm_before_gate: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(x,
                            self.weight,
                            self.bias,
                            z=z,
                            group_size=self.group_size,
                            eps=self.eps,
                            norm_before_gate=self.norm_before_gate)


class RMSNormGated(nn.Module):

    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: Optional[int] = None,
        norm_before_gate: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(x,
                          self.weight,
                          self.bias,
                          z=z,
                          eps=self.eps,
                          group_size=self.group_size,
                          norm_before_gate=self.norm_before_gate)
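A short usage sketch for the gated RMSNorm path (shapes assumed; `rms_norm_ref` above serves as the reference, so results should match up to floating-point rounding):

    layer = RMSNormGated(hidden_size=256, eps=1e-5, norm_before_gate=False).cuda()
    x = torch.randn(8, 256, device='cuda')
    z = torch.randn(8, 256, device='cuda')
    y = layer(x, z)                      # norm(x * silu(z)) * weight
    y_ref = rms_norm_ref(x, layer.weight, None, z=z, eps=layer.eps,
                         norm_before_gate=False)
    torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)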
44  vllm/model_executor/layers/fla/ops/op.py  Normal file
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import os

from vllm.triton_utils import tl, tldevice, triton

if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:

    @triton.jit
    def div_normal(x, y):
        return x / y

    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2


@triton.jit
def safe_exp(x):
    return exp(tl.where(x <= 0, x, float('-inf')))


if not hasattr(tl, 'gather'):

    @triton.jit
    def gather(src, index, axis, _builder=None):
        # Fallback used when tl.gather is unavailable in this Triton version.
        # It only satisfies the compiler; no actual gather is performed.
        return src
else:
    gather = tl.gather
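In eager PyTorch terms, `safe_exp` evaluates exp(x) only where x <= 0 and returns 0 elsewhere (since exp(-inf) == 0); a reference sketch, not part of the diff:

    def safe_exp_ref(x: torch.Tensor) -> torch.Tensor:
        # exp(x) for non-positive x, 0 otherwise
        return torch.exp(torch.where(x <= 0, x, torch.full_like(x, float('-inf'))))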
365  vllm/model_executor/layers/fla/ops/solve_tril.py  Normal file
@@ -0,0 +1,365 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .utils import input_guard


@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def solve_tril_16x16_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16

    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset),
                            (16, 16), (1, 0))
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0),
                             (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    b_A = -tl.where(
        tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)

    o_i = tl.arange(0, 16)
    for i in range(1, min(16, T - i_t * 16)):
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(p_Ai,
             b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))


@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
                                        T, H: tl.constexpr, BT: tl.constexpr,
                                        IS_VARLEN: tl.constexpr):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A += (bos * H + i_h) * 32
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 32

    p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'),
                    Ai_11,
                    input_precision='ieee')
    tl.store(p_Ai_11,
             Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_22,
             Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_21,
             Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))


@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8] for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
                                        T, H: tl.constexpr, BT: tl.constexpr,
                                        IS_VARLEN: tl.constexpr):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A += (bos * H + i_h) * 64
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 64

    p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)

    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)

    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'),
                    Ai_11,
                    input_precision='ieee')
    Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'),
                    Ai_22,
                    input_precision='ieee')
    Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'),
                    Ai_33,
                    input_precision='ieee')

    Ai_31 = -tl.dot(Ai_33,
                    tl.dot(A_31, Ai_11, input_precision='ieee') +
                    tl.dot(A_32, Ai_21, input_precision='ieee'),
                    input_precision='ieee')
    Ai_42 = -tl.dot(Ai_44,
                    tl.dot(A_42, Ai_22, input_precision='ieee') +
                    tl.dot(A_43, Ai_32, input_precision='ieee'),
                    input_precision='ieee')
    Ai_41 = -tl.dot(Ai_44,
                    tl.dot(A_41, Ai_11, input_precision='ieee') +
                    tl.dot(A_42, Ai_21, input_precision='ieee') +
                    tl.dot(A_43, Ai_31, input_precision='ieee'),
                    input_precision='ieee')

    p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
    p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
    p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    tl.store(p_Ai_11,
             Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_22,
             Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_33,
             Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_44,
             Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_21,
             Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_31,
             Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_32,
             Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_41,
             Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_42,
             Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_43,
             Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))

    fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
    p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0))
    p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0))
    p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0))
    p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0))
    p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0))
    p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0))
    tl.store(p_Ai_12,
             fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_13,
             fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_14,
             fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_23,
             fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_24,
             fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))
    tl.store(p_Ai_34,
             fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"),
             boundary_check=(0, 1))


@input_guard
def solve_tril(A: torch.Tensor,
               cu_seqlens: Optional[torch.Tensor] = None,
               output_dtype: torch.dtype = torch.float) -> torch.Tensor:
    """
    Compute the inverse of (I + A) for a batch of lower triangular blocks.
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT is the block size (16, 32 or 64).
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`

    Returns:
        (I + A)^-1 with the same shape as A.
    """
    assert A.shape[-1] in [16, 32, 64]

    B, T, H, BT = A.shape
    Ad = torch.empty(B,
                     T,
                     H,
                     16,
                     device=A.device,
                     dtype=torch.float if BT != 16 else output_dtype)

    chunk_indices = prepare_chunk_indices(
        cu_seqlens, 16) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    if BT == 16:
        return Ad

    Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
    chunk_indices = prepare_chunk_indices(
        cu_seqlens, BT) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    return Ai
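The contract in the docstring can be checked against a dense reference on a single [BT, BT] chunk taken out of the [B, T, H, BT] layout (an editor's illustration under that assumption, not part of the diff):

    def solve_tril_chunk_ref(A_chunk: torch.Tensor) -> torch.Tensor:
        # A_chunk: [BT, BT], strictly lower triangular; returns (I + A_chunk)^-1
        eye = torch.eye(A_chunk.shape[-1], device=A_chunk.device, dtype=A_chunk.dtype)
        return torch.linalg.inv(eye + A_chunk)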
180  vllm/model_executor/layers/fla/ops/utils.py  Normal file
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import contextlib
import functools
import logging
import os
from enum import Enum
from typing import Any, Callable, Literal, Optional

import torch

from vllm.triton_utils import triton

logger = logging.getLogger(__name__)

COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"

SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))


def tensor_cache(
        fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent sets of input tensors.
    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with small fixed-size caching.
    """

    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 4

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries, cache_size
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \
                    and all(a is b for a, b in zip(args, last_args)) \
                    and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
                cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [
                    (args, kwargs, last_result)
                ]
                return last_result

        result = fn(*args, **kwargs)

        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper


def input_guard(
        fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i if not isinstance(i, torch.Tensor) else
                           i.contiguous() for i in args)
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = torch.cuda.device(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


@functools.cache
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        return 'cpu'


@functools.cache
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    device = get_available_device()
    mapping = {
        "cuda": "nvidia",
        "hip": "amd",
        "xpu": "intel",
    }
    # return the mapped value, or the original if not found
    return mapping.get(device, device)


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = (device_platform == 'amd')
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')
is_intel_alchemist = (is_intel
                      and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia
                    and ('NVIDIA H' in torch.cuda.get_device_name(0)
                         or torch.cuda.get_device_capability()[0] >= 9))
use_cuda_graph = (is_nvidia
                  and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)
            ['max_shared_mem'] for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False
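Usage note for `tensor_cache` (illustrative sketch): the cache keys on object identity (`is`), so it only hits when the very same tensor objects are passed again, which is the common case for repeated `cu_seqlens` lookups within one forward pass.

    calls = []

    @tensor_cache
    def lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
        calls.append(1)
        return cu_seqlens[1:] - cu_seqlens[:-1]

    cu = torch.tensor([0, 5, 13])
    lens(cu); lens(cu)                # second call hits the cache (same object)
    assert len(calls) == 1
    lens(torch.tensor([0, 5, 13]))    # equal values but a new object: cache miss
    assert len(calls) == 2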
114  vllm/model_executor/layers/fla/ops/wy_fast.py  Normal file
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501
from typing import Optional

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices


@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8] for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
                             T, H: tl.constexpr, Hg: tl.constexpr,
                             K: tl.constexpr, V: tl.constexpr,
                             BT: tl.constexpr, BK: tl.constexpr,
                             BV: tl.constexpr, IS_VARLEN: tl.constexpr):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
                               (i_t * BT, ), (BT, ), (0, ))
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ),
                            (BT, ), (0, ))
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1),
                            (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0, ))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_g = tl.exp(tl.load(p_g, boundary_check=(0, )))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1),
                                (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1),
                                (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
                                (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
                                (1, 0))
        p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1),
                                (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = prepare_chunk_indices(
        cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    recompute_w_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
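Per BT-row chunk the kernel computes, in matrix form, u = A (beta * v) and w = A (beta * exp(g) * k); an eager per-chunk sketch, assuming H == Hg (no grouped key heads), for orientation only:

    def recompute_w_u_chunk_ref(k, v, beta, g_cumsum, A):
        # k: [BT, K], v: [BT, V], beta/g_cumsum: [BT], A: [BT, BT]
        u = A @ (v * beta[:, None])
        w = A @ (k * beta[:, None] * torch.exp(g_cumsum)[:, None])
        return w, u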
Some files were not shown because too many files have changed in this diff.