[V0 deprecation] Deprecate V0 Neuron backend (#21159)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -149,19 +149,3 @@ steps:
      - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
    env:
      DOCKER_BUILDKIT: "1"

  - block: "Build Neuron release image"
    key: block-neuron-release-image-build
    depends_on: ~

  - label: "Build and publish Neuron release image"
    depends_on: block-neuron-release-image-build
    agents:
      queue: neuron-postmerge
    commands:
      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ."
      - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest"
      - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)"
    env:
      DOCKER_BUILDKIT: "1"
@@ -1,64 +0,0 @@
#!/bin/bash

# This script builds the Neuron docker image and runs the API server inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -e
set -v

image_name="neuron/vllm-ci"
container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"

HF_CACHE="$(realpath ~)/huggingface"
mkdir -p "${HF_CACHE}"
HF_MOUNT="/root/.cache/huggingface"
HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN)

NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache"
mkdir -p "${NEURON_COMPILE_CACHE_URL}"
NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache"

# Try building the docker image
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws

# Prune old images and containers to save disk space, but only once a day,
# by using a timestamp file in tmp.
if [ -f /tmp/neuron-docker-build-timestamp ]; then
    last_build=$(cat /tmp/neuron-docker-build-timestamp)
    current_time=$(date +%s)
    if [ $((current_time - last_build)) -gt 86400 ]; then
        # Remove dangling images (those that are not tagged and not used by any container)
        docker image prune -f
        # Remove unused volumes / force the system prune for old images as well.
        docker volume prune -f && docker system prune -f
        echo "$current_time" > /tmp/neuron-docker-build-timestamp
    fi
else
    date "+%s" > /tmp/neuron-docker-build-timestamp
fi

docker build -t "${image_name}" -f docker/Dockerfile.neuron .

# Setup cleanup
remove_docker_container() {
    docker image rm -f "${image_name}" || true;
}
trap remove_docker_container EXIT

# Run the image
docker run --rm -it --device=/dev/neuron0 --network bridge \
    -v "${HF_CACHE}:${HF_MOUNT}" \
    -e "HF_HOME=${HF_MOUNT}" \
    -e "HF_TOKEN=${HF_TOKEN}" \
    -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
    -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
    --name "${container_name}" \
    ${image_name} \
    /bin/bash -c "
        set -e; # Exit on first error
        python3 /workspace/vllm/examples/offline_inference/neuron.py;
        python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
        for f in /workspace/vllm/tests/neuron/2_core/*.py; do
            echo \"Running test file: \$f\";
            python3 -m pytest \$f -v --capture=tee-sys;
        done
    "
@@ -2,7 +2,6 @@ include LICENSE
include requirements/common.txt
include requirements/cuda.txt
include requirements/rocm.txt
include requirements/neuron.txt
include requirements/cpu.txt
include CMakeLists.txt
@@ -1,56 +0,0 @@
# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04"

FROM $BASE_IMAGE

RUN echo "Base image is $BASE_IMAGE"

# Install some basic utilities
RUN apt-get update && \
    apt-get install -y \
        git \
        python3 \
        python3-pip \
        ffmpeg libsm6 libxext6 libgl1

### Mount Point ###
# When launching the container, mount the code directory to /workspace
ARG APP_MOUNT=/workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}/vllm

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity
RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install pytest

# uninstall transformers-neuronx package explicitly to avoid version conflict
RUN python3 -m pip uninstall -y transformers-neuronx

COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
    if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi

RUN python3 -m pip install -U \
    'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
    -r requirements/neuron.txt

ENV VLLM_TARGET_DEVICE neuron
RUN --mount=type=bind,source=.git,target=.git \
    pip install --no-build-isolation -v -e .

# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils

# install the transformers-neuronx package as an optional dependency (for V0)
# FIXME: `--no-deps` argument is temporarily added to resolve a transformers package version conflict
RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps

RUN python3 -m pip install sentencepiece transformers==4.48.0 -U

# overwrite entrypoint to run bash script
RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py

CMD ["/bin/bash"]
@@ -1,49 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM.
    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        max_num_seqs=8,
        # The max_model_len and block_size arguments are required to be the same
        # as the max sequence length when targeting the neuron device.
        # Currently, this is a known limitation in continuous batching support
        # in transformers-neuronx.
        # TODO(liangfu): Support paged-attention in transformers-neuronx.
        max_model_len=1024,
        block_size=1024,
        # ruff: noqa: E501
        # The device can be automatically detected when AWS Neuron SDK is installed.
        # The device argument can be either unspecified for automated detection,
        # or explicitly assigned.
        device="neuron",
        tensor_parallel_size=2,
    )
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


if __name__ == "__main__":
    main()
@@ -1,61 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "What is annapurna labs?",
]


def main():
    # Create a sampling params object.
    sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)

    # Create an LLM.
    llm = LLM(
        model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct",
        speculative_config={
            "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
            "num_speculative_tokens": 5,
            "max_model_len": 2048,
        },
        max_num_seqs=4,
        # The max_model_len and block_size arguments are required to be the same
        # as the max sequence length when targeting the neuron device.
        # Currently, this is a known limitation in continuous batching support
        # in neuronx-distributed-inference.
        max_model_len=2048,
        block_size=2048,
        # The device can be automatically detected when AWS Neuron SDK is installed.
        # The device argument can be either unspecified for automated detection,
        # or explicitly assigned.
        device="neuron",
        tensor_parallel_size=32,
        override_neuron_config={
            "enable_eagle_speculation": True,
            "enable_fused_speculation": True,
        },
    )

    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
@@ -1,63 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

from vllm import LLM, SamplingParams

# creates XLA hlo graphs for all the context length buckets.
os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
# Quantizes neuron model weights to int8.
# The default config for quantization is int8 dtype.
os.environ["NEURON_QUANT_DTYPE"] = "s8"

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM.
    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        max_num_seqs=8,
        # The max_model_len and block_size arguments are required to be the same
        # as the max sequence length when targeting the neuron device.
        # Currently, this is a known limitation in continuous batching support
        # in transformers-neuronx.
        # TODO(liangfu): Support paged-attention in transformers-neuronx.
        max_model_len=2048,
        block_size=2048,
        # ruff: noqa: E501
        # The device can be automatically detected when AWS Neuron SDK is installed.
        # The device argument can be either unspecified for automated detection,
        # or explicitly assigned.
        device="neuron",
        quantization="neuron_quant",
        override_neuron_config={
            "cast_logits_dtype": "bfloat16",
        },
        tensor_parallel_size=2,
    )
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


if __name__ == "__main__":
    main()
@@ -1,110 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import requests
import torch
from neuronx_distributed_inference.models.mllama.utils import add_instruct
from PIL import Image

from vllm import LLM, SamplingParams, TextPrompt


def get_image(image_url):
    image = Image.open(requests.get(image_url, stream=True).raw)
    return image


# Model Inputs
PROMPTS = [
    "What is in this image? Tell me a story",
    "What is the recipe of mayonnaise in two sentences?",
    "Describe this image",
    "What is the capital of Italy famous for?",
]
IMAGES = [
    get_image(
        "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
    ),
    None,
    get_image(
        "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
    ),
    None,
]
SAMPLING_PARAMS = [
    dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16)
    for _ in range(len(PROMPTS))
]


def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params):
    # Prepare all inputs for mllama generation, including:
    # 1. put the text prompt into the instruct chat template
    # 2. compose a single text and single image prompt into vLLM's prompt class
    # 3. prepare sampling parameters
    input_image = single_image
    has_image = torch.tensor([1])
    if isinstance(single_image, torch.Tensor) and single_image.numel() == 0:
        has_image = torch.tensor([0])

    instruct_prompt = add_instruct(prompt, has_image)
    inputs = TextPrompt(prompt=instruct_prompt)

    if input_image is not None:
        inputs["multi_modal_data"] = {"image": input_image}

    sampling_params = SamplingParams(**sampling_params)
    return inputs, sampling_params


def print_outputs(outputs):
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def main():
    assert (
        len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS)
    ), f"""Text, image prompts and sampling parameters should have the
        same batch size; but got {len(PROMPTS)}, {len(IMAGES)},
        and {len(SAMPLING_PARAMS)}"""

    # Create an LLM.
    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_num_seqs=1,
        max_model_len=4096,
        block_size=4096,
        device="neuron",
        tensor_parallel_size=32,
        override_neuron_config={
            "sequence_parallel_enabled": False,
            "skip_warmup": True,
            "save_sharded_checkpoint": True,
            "on_device_sampling_config": {
                "global_topk": 1,
                "dynamic": False,
                "deterministic": False,
            },
        },
    )

    batched_inputs = []
    batched_sample_params = []
    for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS):
        inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params)
        # test batch-size = 1
        outputs = llm.generate(inputs, sampling_params)
        print_outputs(outputs)
        batched_inputs.append(inputs)
        batched_sample_params.append(sampling_params)

    # test batch-size = 4
    outputs = llm.generate(batched_inputs, batched_sample_params)
    print_outputs(outputs)


if __name__ == "__main__":
    main()
@@ -1,64 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""

import os

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, I am a language model and I can help",
    "The president of the United States is",
    "The capital of France is",
]


def config_buckets():
    """Configure context length and token gen buckets."""
    # creates XLA hlo graphs for all the context length buckets.
    os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
    # creates XLA hlo graphs for all the token gen buckets.
    os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"


def initialize_llm():
    """Create an LLM with speculative decoding."""
    return LLM(
        model="openlm-research/open_llama_7b",
        speculative_config={
            "model": "openlm-research/open_llama_3b",
            "num_speculative_tokens": 4,
            "max_model_len": 2048,
        },
        max_num_seqs=4,
        max_model_len=2048,
        block_size=2048,
        device="neuron",
        tensor_parallel_size=32,
    )


def process_requests(llm: LLM, sampling_params: SamplingParams):
    """Generate texts from prompts and print them."""
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def main():
    """Main function that sets up the llm and processes prompts."""
    config_buckets()
    llm = initialize_llm()
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, top_k=1)
    process_requests(llm, sampling_params)


if __name__ == "__main__":
    main()
@@ -1,9 +0,0 @@
# Common dependencies
-r common.txt

# Dependencies for Neuron devices
packaging>=24.2
setuptools>=77.0.3,<80.0.0
torch-neuronx >= 2.5.0
neuronx-cc>=2.0.0a0
torchvision # Required for Llama3.2 multimodal image preprocessing
setup.py
@@ -413,8 +413,7 @@ def _no_device() -> bool:

def _is_cuda() -> bool:
    has_cuda = torch.version.cuda is not None
    return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
            and not (_is_neuron() or _is_tpu()))
    return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu())


def _is_hip() -> bool:
@@ -422,10 +421,6 @@ def _is_hip() -> bool:
            or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None


def _is_neuron() -> bool:
    return VLLM_TARGET_DEVICE == "neuron"


def _is_tpu() -> bool:
    return VLLM_TARGET_DEVICE == "tpu"

@@ -470,25 +465,6 @@ def get_rocm_version():
    return None


def get_neuronxcc_version():
    import sysconfig
    site_dir = sysconfig.get_paths()["purelib"]
    version_file = os.path.join(site_dir, "neuronxcc", "version",
                                "__init__.py")

    # Check if the command was executed successfully
    with open(version_file) as fp:
        content = fp.read()

    # Extract the version using a regular expression
    match = re.search(r"__version__ = '(\S+)'", content)
    if match:
        # Return the version string
        return match.group(1)
    else:
        raise RuntimeError("Could not find Neuron version in the output")


def get_nvcc_cuda_version() -> Version:
    """Get the CUDA version from nvcc.

@@ -541,12 +517,6 @@ def get_vllm_version() -> str:
        rocm_version = get_rocm_version() or torch.version.hip
        if rocm_version and rocm_version != MAIN_CUDA_VERSION:
            version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
    elif _is_neuron():
        # Get the Neuron version
        neuron_version = str(get_neuronxcc_version())
        if neuron_version != MAIN_CUDA_VERSION:
            neuron_version_str = neuron_version.replace(".", "")[:3]
            version += f"{sep}neuron{neuron_version_str}"
    elif _is_tpu():
        version += f"{sep}tpu"
    elif _is_cpu():
@@ -591,8 +561,6 @@ def get_requirements() -> list[str]:
        requirements = modified_requirements
    elif _is_hip():
        requirements = _read_requirements("rocm.txt")
    elif _is_neuron():
        requirements = _read_requirements("neuron.txt")
    elif _is_tpu():
        requirements = _read_requirements("tpu.txt")
    elif _is_cpu():
@@ -601,7 +569,7 @@ def get_requirements() -> list[str]:
        requirements = _read_requirements("xpu.txt")
    else:
        raise ValueError(
            "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
            "Unsupported platform, please use CUDA, ROCm, or CPU.")
    return requirements
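For readers skimming the setup.py hunk above, the removed branch simply suffixed the wheel version with a truncated Neuron compiler version. A minimal sketch of that string manipulation follows (editor's illustration, not part of the commit; the version values are made up rather than read from an installed neuronxcc):

    # Sketch of the removed version-suffix logic in get_vllm_version().
    base_version = "0.9.2"          # hypothetical vLLM base version
    neuron_version = "2.19.8089.0"  # hypothetical neuronxcc __version__
    sep = "+"
    # Same truncation the removed code used: drop dots, keep the first 3 characters.
    neuron_version_str = neuron_version.replace(".", "")[:3]
    print(base_version + f"{sep}neuron{neuron_version_str}")  # -> 0.9.2+neuron219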
@@ -287,15 +287,6 @@ def test_prefix_cache_default():
        },
        "mm-processor-kwargs"
    ),
    (
        '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}',
        {
            "cast_logits_dtype": "bfloat16",
            "sequence_parallel_norm": True,
            "sequence_parallel_norm_threshold": 2048,
        },
        "override-neuron-config"
    ),
])
# yapf: enable
def test_composite_arg_parser(arg, expected, option):
@@ -1,43 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.activation import FastGELU, SiluAndMul
from vllm.platforms import current_platform


@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
    (7, 512, torch.half),
    (7, 512, torch.float),
    (83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device)
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        fn = layer.forward_native
    elif activation == "gelu_fast":
        layer = FastGELU()
        fn = F.gelu
    else:
        raise NotImplementedError(
            f"activation {activation} is not implemented.")
    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device).forward_neuron(x)
    ref_out = fn(x.cpu())
    torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0)
@@ -1,154 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki

from vllm.attention.ops.nki_flash_attn import (
    load_block_tables, transform_block_tables_for_indirect_load)


def is_power_of_2(n):
    return n > 0 and (n & (n - 1) == 0)


def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
    block_tables_sbuf = load_block_tables(block_tables, num_tiles,
                                          num_blocks_per_tile)

    # we need to pass an Index as head_id
    head_id = nl.arange(1)[None, :] + head_id

    block_tables_transposed = transform_block_tables_for_indirect_load(
        block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
    B_P_SIZE = 128
    assert block_tables_transposed.shape[1] == B_P_SIZE

    out = nl.ndarray(
        block_tables_transposed.shape,
        dtype=nl.int32,
        buffer=nl.shared_hbm,
    )
    for i in nl.affine_range(block_tables_transposed.shape[0]):
        nl.store(dst=out[i], value=block_tables_transposed[i])
    return out


def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert block_tables.numel() == num_tiles * num_blocks_per_tile
    block_tables = block_tables.view(num_tiles, num_blocks_per_tile)
    B_F_SIZE = 128
    num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
    block_tables = F.pad(
        block_tables,
        (0, 0, 0, num_tiles_padded - num_tiles),
        "constant",
        0,
    )

    block_tables = block_tables * num_head + head_id
    block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
    offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    block_tables = block_tables * block_size_tiling_factor + offset
    block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()

    num_blocks_per_tile = block_tables_transposed.shape[0]
    assert num_blocks_per_tile % B_F_SIZE == 0
    return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
                                        B_F_SIZE, num_tiles_padded)


@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    monkeypatch: pytest.MonkeyPatch,
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    compiler_flags_str = " ".join([
        "-O1",
        "--retry_failed_compilation",
    ])
    with monkeypatch.context() as m:
        m.setenv("NEURON_CC_FLAGS", compiler_flags_str)

        torch.manual_seed(10000)
        torch.set_printoptions(sci_mode=False)

        # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
        B_P_SIZE = 128
        if num_blocks_per_tile < B_P_SIZE:
            assert B_P_SIZE % num_blocks_per_tile == 0
            block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
        else:
            block_size_tiling_factor = 1
        max_num_blocks = 100000
        block_tables = torch.randint(
            0,
            max_num_blocks,
            (num_tiles * num_blocks_per_tile, ),
            dtype=torch.int32,
        )
        nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
            block_tables.to(device=device),
            num_tiles,
            num_blocks_per_tile,
            q_head_per_kv_head,
            head_id,
            block_size_tiling_factor,
        ).cpu()
        ref_out = ref_block_tables_transform(
            block_tables,
            num_tiles,
            num_blocks_per_tile,
            q_head_per_kv_head,
            head_id,
            block_size_tiling_factor,
        )
        assert (nki_out.shape == ref_out.shape
                ), f"{nki_out.shape=} != {ref_out.shape=}"
        assert torch.all(nki_out == ref_out)
@@ -1,86 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.attention.ops.nki_flash_attn import reshape_and_cache


@pytest.mark.parametrize(
    "num_tokens, n_kv_head, d_head, num_blocks, block_size",
    [
        # Small model configuration (e.g., GPT-2 small)
        (32, 12, 64, 4, 128),  # Typical sequence processing
        (1, 12, 64, 4, 128),  # Single token update
        (128, 12, 64, 4, 128),  # Longer sequence

        # Medium model configuration (e.g., GPT-2 medium)
        (64, 16, 96, 8, 256),  # Standard batch
        (256, 16, 96, 8, 256),  # Large batch

        # Large model configuration (e.g., GPT-3 style)
        (48, 32, 128, 16, 512),  # Typical processing window
        (512, 32, 128, 16, 512),  # Full context window

        # Edge cases and stress tests
        (1024, 8, 32, 32, 32),  # Many tokens, small heads
        (16, 64, 256, 4, 64),  # Few tokens, many heads
        (2048, 24, 128, 64, 128),  # Large scale test

        # Minimal configurations for debugging
        (4, 2, 16, 2, 16),  # Tiny test case
        (1, 1, 8, 1, 8),  # Minimal possible
    ])
def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
                           block_size):
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create CPU tensors for reference implementation
    key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]

    # Run reference implementation on CPU
    block_indices = torch.div(slot_mapping_cpu,
                              block_size,
                              rounding_mode="floor")
    block_offsets = slot_mapping_cpu % block_size

    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
        value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]

    # Create XLA device tensors
    device = torch.device('xla')
    key = key_cpu.to(device)
    value = value_cpu.to(device)
    key_cache = torch.zeros_like(key_cache_cpu, device=device)
    value_cache = torch.zeros_like(value_cache_cpu, device=device)
    slot_mapping = slot_mapping_cpu.to(device)
    kv_cache = torch.stack([key_cache, value_cache])

    # Run vectorized implementation on XLA device
    reshape_and_cache(key, value, kv_cache, slot_mapping)
    key_cache, value_cache = torch.unbind(kv_cache, dim=0)

    # Move results back to CPU for comparison
    key_cache_result = key_cache.cpu()
    value_cache_result = value_cache.cpu()

    # Assert results match
    torch.testing.assert_close(key_cache_result,
                               key_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
    torch.testing.assert_close(value_cache_result,
                               value_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
@@ -1,57 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
    (7, 8, False, torch.half),
    (83, 768, False, torch.half),
    (83, 768, True, torch.half),
    (83, 768, True, torch.bfloat16),
    (83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    residual_cpu = residual.cpu() if add_residual else None
    ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device)(x, residual)

    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.
    # Therefore, we use a larger tolerance.
    if add_residual:
        assert out[0].is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out[0].cpu(),
                                   ref_out[0],
                                   atol=1e-2,
                                   rtol=1e-2)
        torch.testing.assert_close(out[1].cpu(),
                                   ref_out[1],
                                   atol=1e-2,
                                   rtol=1e-2)
    else:
        assert out.is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)
@@ -1,95 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available


class MockLogitsProcessor(LogitsProcessor):

    def __init__(self, vocab_size: int, scale: float,
                 fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size, scale=scale)
        self.fake_logits = fake_logits.clone()

    def forward(self, *args, **kwargs):
        with patch(
                "vllm.model_executor.layers.logits_processor._prune_hidden_states",
                lambda x, y: x
        ), patch(
                "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
                lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
    return input_tensor, fake_logits, logits_processor


RANDOM_SEEDS = list(range(8))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    set_random_seed(seed)
    torch.set_default_device("cpu")
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

    # This sample logits processor gives an infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    logits_processor_output = logits_processor(
        lm_head=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata)

    fake_logits *= logits_processor.scale
    torch.testing.assert_close(logits_processor_output[:, 1],
                               fake_logits[:, 1],
                               rtol=1e-4,
                               atol=0.0)
@@ -1,127 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from unittest.mock import MagicMock

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import NeuronModelRunner

os.environ[
    'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value


def _create_neuron_model_runner(model: str, *args,
                                **kwargs) -> NeuronModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    vllm_config = VllmConfig(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
    )
    neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config)
    return neuron_model_runner


def test_update_neuron_sampling_params_not_full_batch():
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled
    # Test sampling param updating only when TNx is the framework;
    # NxDI handles sampling parameter updating inside the model.
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        seq_group_metadata_list = [
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0.5,
                                               top_k=1,
                                               top_p=0.5),
                block_tables={0: [1]},
            )
        ]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Index neuron sampling parameters based on block_tables indices.
        # The first block_id of sequence 0 is 1, so its parameters are
        # placed at index 1. So the sampling parameters will be:
        # Index 0: default sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        assert neuron_sampling_params.temperature == [1.0, 0.5]
        assert neuron_sampling_params.top_k == [
            model_runner._MAX_NEURON_SAMPLING_TOP_K, 1
        ]
        assert neuron_sampling_params.top_p == [1.0, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)


def test_update_neuron_sampling_params_full_batch():
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled

    # Test sampling param updating only when TNx is the framework;
    # NxDI handles sampling parameter updating inside the model.
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        seq_group_metadata_list = [
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0.5,
                                               top_k=1,
                                               top_p=0.5),
                block_tables={0: [1]},
            ),
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={1: SequenceData.from_seqs([4, 5, 6])},
                sampling_params=SamplingParams(temperature=0.2,
                                               top_k=2,
                                               top_p=0.2),
                block_tables={1: [0]},
            )
        ]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Index neuron sampling parameters based on block_tables indices.
        # The first block_id of sequence 0 is 1, so its parameters are
        # placed at index 1. So the sampling parameters will be:
        # Index 0: sequence 1's sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        assert neuron_sampling_params.temperature == [0.2, 0.5]
        assert neuron_sampling_params.top_k == [2, 1]
        assert neuron_sampling_params.top_p == [0.2, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)
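The slot layout asserted by the removed tests above can be restated as a tiny, self-contained sketch (editor's illustration only; the default top_k value here is a made-up stand-in for _MAX_NEURON_SAMPLING_TOP_K):

    # Illustrative only: parameters land in the slot given by a sequence's
    # first block id; untouched slots keep the defaults.
    DEFAULT_TOP_K = 256  # assumed placeholder value

    temperature = [1.0, 1.0]                      # max_num_seqs=2 -> two slots
    top_k = [DEFAULT_TOP_K, DEFAULT_TOP_K]
    top_p = [1.0, 1.0]

    slot = 1                                      # sequence 0's first block id is 1
    temperature[slot], top_k[slot], top_p[slot] = 0.5, 1, 0.5

    assert temperature == [1.0, 0.5] and top_k == [DEFAULT_TOP_K, 1]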
@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.quantization.neuron_quant import (
    NeuronQuantConfig)


def test_get_supported_act_dtypes():
    neuron_quant_config = NeuronQuantConfig()
    supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
    target_list = ["any_dtype1", "any_dtype2"]
    for dtype in target_list:
        assert dtype in supported_act_dtypes
@ -1,514 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
class BlockDiagonalCausalFromBottomRightMask:
|
||||
|
||||
@staticmethod
|
||||
def _from_seqlens(query_lens, seq_lens, block_size=None):
|
||||
from torch import logical_and, logical_or
|
||||
|
||||
contexted = block_size is None
|
||||
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
||||
n_queries = sum(query_lens)
|
||||
num_seqs = len(query_lens)
|
||||
if contexted:
|
||||
key_lens_blockaligned = seq_lens
|
||||
else:
|
||||
n_blocks_per_seq = (context_lens + block_size - 1) // block_size
|
||||
offset_per_seq = n_blocks_per_seq * block_size
|
||||
key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
|
||||
n_keys = sum(key_lens_blockaligned)
|
||||
|
||||
a = (torch.arange(n_queries).reshape(n_queries,
|
||||
1).expand(n_queries, n_keys))
|
||||
b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
|
||||
q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
|
||||
k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)
|
||||
|
||||
prior_mask = torch.zeros(n_queries, n_keys)
|
||||
new_masks: list[torch.Tensor] = []
|
||||
for seq_id in range(num_seqs):
|
||||
ri = q_cumsum[seq_id]
|
||||
ci = k_cumsum[seq_id]
|
||||
nr = query_lens[seq_id]
|
||||
|
||||
if contexted:
|
||||
nc = seq_lens[seq_id]
|
||||
a_offset = ci + nc - ri - nr
|
||||
new_mask = (a + a_offset) >= b
|
||||
else:
|
||||
nc = context_lens[seq_id]
|
||||
a_offset = ci + nc - 1
|
||||
new_mask = a_offset >= b
|
||||
|
||||
left_mask = b >= ci
|
||||
top_mask = a >= ri
|
||||
bottom_mask = a < (ri + nr)
|
||||
|
||||
new_mask = logical_and(
|
||||
logical_and(logical_and(new_mask, left_mask), top_mask),
|
||||
bottom_mask,
|
||||
)
|
||||
prior_mask = logical_or(prior_mask, new_mask)
|
||||
new_masks = new_masks + [new_mask]
|
||||
return prior_mask
|
||||
|
||||
@staticmethod
|
||||
def from_seqlens(query_lens, seq_lens, block_size=None):
|
||||
contexted = block_size is None
|
||||
if contexted:
|
||||
prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
|
||||
query_lens, seq_lens)
|
||||
active_mask = None
|
||||
else:
|
||||
prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
|
||||
query_lens, seq_lens, block_size)
|
||||
active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
|
||||
query_lens, query_lens)
|
||||
return prior_mask, active_mask
|
||||
|
||||
|
||||
def ref_softmax(x: torch.Tensor,
|
||||
dim: int,
|
||||
mixed_precision=False,
|
||||
return_max_reduce=False):
|
||||
max_value = torch.amax(x, dim=dim, keepdims=True)
|
||||
exp = torch.exp(x - max_value)
|
||||
if mixed_precision:
|
||||
sum_value = torch.sum(exp.astype(torch.float32),
|
||||
dim=dim,
|
||||
keepdims=True).astype(x.dtype)
|
||||
else:
|
||||
sum_value = torch.sum(exp, dim=dim, keepdims=True)
|
||||
if return_max_reduce:
|
||||
return exp / sum_value, max_value, torch.reciprocal(sum_value)
|
||||
return exp / sum_value
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
return_max_reduce: Optional[bool] = False,
|
||||
) -> torch.Tensor:
|
||||
scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
if attn_mask is not None:
|
||||
masked_score = scaled_qk + attn_mask.float()
|
||||
if return_max_reduce:
|
||||
norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
|
||||
masked_score, dim=-1, return_max_reduce=True)
|
||||
else:
|
||||
norm_score = ref_softmax(masked_score, dim=-1)
|
||||
out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
|
||||
if return_max_reduce:
|
||||
return (
|
||||
out,
|
||||
cached_max,
|
||||
cached_sum_reciprocal,
|
||||
norm_score,
|
||||
masked_score,
|
||||
scaled_qk,
|
||||
)
|
||||
else:
|
||||
return (out, )
|
||||
|
||||
|
||||
def ref_context_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query_lens,
|
||||
seq_lens,
|
||||
head_size,
|
||||
num_queries_per_kv,
|
||||
return_max_reduce=False,
|
||||
):
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
|
||||
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
|
||||
|
||||
attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||
query_lens, seq_lens)
|
||||
|
||||
# convert binary mask to -inf values
|
||||
attn_mask = torch.logical_not(attn_mask)
|
||||
attn_mask = attn_mask.float() * -30000
|
||||
|
||||
output, *debug_tensors = ref_masked_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale,
|
||||
attn_mask,
|
||||
return_max_reduce=return_max_reduce,
|
||||
)
|
||||
|
||||
output = output.unsqueeze(1)
|
||||
if return_max_reduce:
|
||||
cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
|
||||
debug_tensors)
|
||||
return (
|
||||
output,
|
||||
cached_max,
|
||||
cached_sum_reciprocal,
|
||||
lse,
|
||||
masked_score,
|
||||
scaled_qk,
|
||||
)
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
def sample_inputs(
|
||||
prefill_batch_size,
|
||||
decode_batch_size,
|
||||
min_query_len,
|
||||
max_query_len,
|
||||
min_ctx_len,
|
||||
max_ctx_len,
|
||||
block_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
):
|
||||
batch_size = prefill_batch_size + decode_batch_size
|
||||
max_model_len = (max_query_len + max_ctx_len) * 4
|
||||
max_block_per_request = max_model_len // block_size
|
||||
cache_size = (batch_size * max_block_per_request) + 2
|
||||
prefill_ctx_lens = torch.randint(min_ctx_len,
|
||||
max_ctx_len + 1, (prefill_batch_size, ),
|
||||
dtype=torch.long).tolist()
|
||||
decode_ctx_lens = torch.randint(min_ctx_len,
|
||||
max_ctx_len + 1, (decode_batch_size, ),
|
||||
dtype=torch.long).tolist()
|
||||
ctx_lens = prefill_ctx_lens + decode_ctx_lens
|
||||
query_lens = torch.randint(
|
||||
min_query_len,
|
||||
max_query_len + 1,
|
||||
(prefill_batch_size, ),
|
||||
dtype=torch.long,
|
||||
).tolist() + [1 for _ in range(decode_batch_size)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
|
||||
num_tokens = sum(query_lens)
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-1, 1)
|
||||
torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
|
||||
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
|
||||
kv.uniform_(-1, 1)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:batch_size * max_block_per_request].view(
|
||||
batch_size, max_block_per_request)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
for i in range(batch_size):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
|
||||
j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
|
||||
b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
start_loc = b_seq_start_loc[i] + cur_ctx
|
||||
if cur_ctx + block_size > b_ctx_len[i]:
|
||||
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
|
||||
else:
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc])
|
||||
v_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc])
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
kv_cache = torch.stack([k_cache, v_cache])
|
||||
|
||||
return (
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
block_table,
|
||||
key,
|
||||
value,
|
||||
query_lens,
|
||||
seq_lens,
|
||||
)
|
||||
|
||||
|
||||
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
|
||||
num_blocks):
|
||||
context_lens = seq_lens - query_lens
|
||||
blocks_per_seq = (context_lens + block_size - 1) // block_size
|
||||
num_seqs = len(seq_lens)
|
||||
active_blocks: list[int] = []
|
||||
for seq_id in range(num_seqs):
|
||||
active_blocks = (
|
||||
active_blocks +
|
||||
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
|
||||
return F.pad(
|
||||
torch.tensor(active_blocks, dtype=torch.int32),
|
||||
(0, num_blocks - len(active_blocks)),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision",
|
||||
[
|
||||
# Test minimal configurations (small block size)
|
||||
(1, 199, 1, 512, 4, 2, 8, False
|
||||
), # minimal block size, small dimensions
|
||||
(1, 199, 1, 512, 4, 2, 8, True), # same with mixed precision
|
||||
|
||||
# Test common/medium configurations
|
||||
(4, 12, 32, 2048, 32, 8, 64, False), # common case, larger heads
|
||||
(4, 12, 32, 2048, 16, 4, 32,
|
||||
True), # medium size, mixed precision, grouped-query attention (GQA)
|
||||
|
||||
# Test large configurations
|
||||
(4, 12, 256, 8192, 8, 1, 128, False), # large blocks, large head size
|
||||
(4, 12, 256, 8192, 64, 8, 64, True), # large blocks, many heads
|
||||
|
||||
# Test asymmetric configurations
|
||||
(2, 24, 64, 4096, 12, 4, 96, False), # varied batch sizes
|
||||
(8, 8, 128, 2048, 24, 2, 48, True), # balanced batches
|
||||
|
||||
# Test edge cases
|
||||
(1, 128, 16, 1024, 4, 2, 16, False), # large decode batch
|
||||
(16, 4, 8, 1024, 4, 2, 128, True), # large prefill batch
|
||||
(4, 12, 32, 2048, 16, 1, 32, True), # multi-head attention (MHA)
|
||||
(4, 12, 32, 2048, 16, 16, 32, True), # multi-query attention (MQA)
|
||||
])
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
prefill_batch_size: int,
|
||||
decode_batch_size: int,
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
large_tile_size,
|
||||
mixed_precision: bool,
|
||||
) -> None:
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
|
||||
reorder_context_mask)
|
||||
|
||||
assert large_tile_size % block_size == 0
|
||||
|
||||
device = xm.xla_device()
|
||||
|
||||
compiler_flags_str = " ".join([
|
||||
"-O1",
|
||||
"--retry_failed_compilation",
|
||||
])
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("NEURON_CC_FLAGS", compiler_flags_str)
|
||||
|
||||
torch.manual_seed(0)
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
torch.set_default_device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
min_ctx_len = 32
|
||||
max_ctx_len = 1024
|
||||
min_query_len = 16
|
||||
max_query_len = 512
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
(
|
||||
query,
|
||||
k_active,
|
||||
v_active,
|
||||
kv_cache,
|
||||
block_table,
|
||||
key,
|
||||
value,
|
||||
query_lens,
|
||||
seq_lens,
|
||||
) = sample_inputs(
|
||||
prefill_batch_size=prefill_batch_size,
|
||||
decode_batch_size=decode_batch_size,
|
||||
min_query_len=min_query_len,
|
||||
max_query_len=max_query_len,
|
||||
min_ctx_len=min_ctx_len,
|
||||
max_ctx_len=max_ctx_len,
|
||||
block_size=block_size,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
output_ref = ref_context_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query_lens,
|
||||
seq_lens,
|
||||
head_size,
|
||||
num_queries_per_kv,
|
||||
return_max_reduce=False,
|
||||
)
|
||||
|
||||
# build neuron program
|
||||
B_P_SIZE = 128
|
||||
assert (large_tile_size >= B_P_SIZE
|
||||
), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
|
||||
|
||||
def pad_to_multiple(a, b):
|
||||
return cdiv(a, b) * b
|
||||
|
||||
def pad_to_next_power_of_2(a):
|
||||
assert a > 0
|
||||
return 2**int(a - 1).bit_length()
|
||||
|
||||
# calculate input shapes
|
||||
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
|
||||
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
||||
num_active_blocks = cdiv(context_lens, block_size).sum().item()
|
||||
num_active_blocks = pad_to_multiple(num_active_blocks,
|
||||
large_tile_size // block_size)
|
||||
context_kv_len = num_active_blocks * block_size
|
||||
assert (
|
||||
context_kv_len %
|
||||
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
|
||||
|
||||
# pad QKV tensors
|
||||
pad_dims = (
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
max_num_queries - query.shape[0],
|
||||
)
|
||||
query = F.pad(query, pad_dims, "constant", 0)
|
||||
k = F.pad(k_active, pad_dims, "constant", 0)
|
||||
v = F.pad(v_active, pad_dims, "constant", 0)
|
||||
|
||||
# permute QKV tensors
|
||||
# query: (1, n_heads, d, seq_q)
|
||||
# key: (1, n_kv_heads, d, seq_k)
|
||||
# value: (1, n_kv_heads, seq_v, d)
|
||||
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
||||
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
||||
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
|
||||
kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()
|
||||
|
||||
# transform block table
|
||||
active_block_table = get_active_block_tables(
|
||||
block_table.cpu(),
|
||||
torch.tensor(query_lens).cpu(),
|
||||
torch.tensor(seq_lens).cpu(),
|
||||
block_size,
|
||||
num_active_blocks,
|
||||
)
|
||||
|
||||
# Build attention masks
|
||||
prior_mask, active_mask = (
|
||||
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||
query_lens, seq_lens, block_size=block_size))
|
||||
prior_mask_padded = F.pad(
|
||||
prior_mask,
|
||||
(
|
||||
0,
|
||||
context_kv_len - prior_mask.shape[1],
|
||||
0,
|
||||
max_num_queries - prior_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
).bool()
|
||||
active_mask_padded = F.pad(
|
||||
active_mask,
|
||||
(
|
||||
0,
|
||||
max_num_queries - active_mask.shape[1],
|
||||
0,
|
||||
max_num_queries - active_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
).bool()
|
||||
attn_mask = torch.concat([prior_mask_padded, active_mask_padded],
|
||||
dim=1)
|
||||
|
||||
attn_mask = reorder_context_mask(attn_mask, large_tile_size,
|
||||
block_size)
|
||||
|
||||
input_args = (
|
||||
query.to(device=device),
|
||||
k.to(device=device),
|
||||
v.to(device=device),
|
||||
kv_cache.to(device=device),
|
||||
active_block_table.to(device=device),
|
||||
attn_mask.to(device=device),
|
||||
)
|
||||
input_kwargs = dict(
|
||||
n_kv_head=num_kv_heads,
|
||||
head_size=head_size,
|
||||
mixed_precision=mixed_precision,
|
||||
LARGE_TILE_SZ=large_tile_size,
|
||||
)
|
||||
|
||||
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
|
||||
|
||||
num_actual_tokens = sum(query_lens)
|
||||
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
|
||||
output_nki = output_nki.cpu().permute(0, 2, 1, 3)
|
||||
output_nki = output_nki[0, :num_actual_tokens, :, :]
|
||||
output_ref_padded = F.pad(
|
||||
output_ref,
|
||||
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
output_ref = output_ref_padded.transpose(
|
||||
0, 1)[0, :num_actual_tokens, :, :]
|
||||
|
||||
torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)
|
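A minimal standalone sketch of the shape bookkeeping the test above performs before invoking the NKI kernel; the sequence lengths and tile sizes below are assumed values for illustration only.

def cdiv(a, b):
    return (a + b - 1) // b

def pad_to_multiple(a, b):
    return cdiv(a, b) * b

def pad_to_next_power_of_2(a):
    assert a > 0
    return 2 ** int(a - 1).bit_length()

# Hypothetical inputs: two sequences with their query and total lengths.
query_lens = [7, 5]
seq_lens = [100, 37]
block_size, large_tile_size = 32, 2048

max_num_queries = pad_to_next_power_of_2(sum(query_lens))           # 16
context_lens = [s - q for s, q in zip(seq_lens, query_lens)]        # [93, 32]
num_active_blocks = sum(cdiv(c, block_size) for c in context_lens)  # 4
num_active_blocks = pad_to_multiple(num_active_blocks,
                                    large_tile_size // block_size)  # 64
context_kv_len = num_active_blocks * block_size                     # 2048
assert context_kv_len % large_tile_size == 0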
@ -1,68 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for miscellaneous utilities
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [
|
||||
(16, False, 32, 32, 1024, True),
|
||||
(16, False, 32, 128, 1024, True),
|
||||
(16, True, 32, 32, 1024, True),
|
||||
(16, True, 32, 128, 1024, True),
|
||||
(16, False, 32, 128, 1024, False),
|
||||
(16, True, 32, 128, 1024, False),
|
||||
])
|
||||
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
|
||||
head_size, seq_len, use_key):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
device = xm.xla_device()
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
batch_size = 1
|
||||
base = 10000
|
||||
num_heads = 8
|
||||
|
||||
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, torch.float32)
|
||||
|
||||
positions = torch.randint(0,
|
||||
max_position, (batch_size, seq_len),
|
||||
device="cpu")
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=torch.float32,
|
||||
device="cpu")
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
assert positions.is_cpu, \
|
||||
"reference input tensor is expected to be CPU tensor."
|
||||
ref_query, ref_key = rot.to(device="cpu").forward_native(
|
||||
positions, query, key)
|
||||
out_query, out_key = rot.to(device=device).forward_neuron(
|
||||
positions.to(device=device), query.to(device=device),
|
||||
key.to(device=device) if key is not None else None)
|
||||
if use_key:
|
||||
assert out_query.is_xla and out_key.is_xla, \
|
||||
"output tensor is expected to be XLA tensor"
|
||||
torch.testing.assert_close(out_key.cpu(),
|
||||
ref_key,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
else:
|
||||
assert out_key is None, "expected returned key to be None"
|
||||
assert out_query.is_xla, \
|
||||
"output tensor is expected to be XLA tensor"
|
||||
torch.testing.assert_close(out_query.cpu(),
|
||||
ref_query,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
@ -1,101 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.utils import get_distributed_init_method, get_open_port
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]:
|
||||
"""Decorator to reinitialize the Neuron Runtime before executing a test.
|
||||
This is necessary for distributed tests which need to reallocate Neuron
|
||||
Cores to separate subprocesses.
|
||||
"""
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
runtime = torch.classes.neuron.Runtime()
|
||||
runtime.initialize()
|
||||
runtime.unsafe_close()
|
||||
|
||||
f(*args, **kwargs)
|
||||
runtime.initialize()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def all_gather_test_worker(index, tp_degree, distributed_init_method):
|
||||
init_distributed_environment(tp_degree,
|
||||
index,
|
||||
distributed_init_method,
|
||||
index,
|
||||
backend="xla")
|
||||
ensure_model_parallel_initialized(tp_degree, 1)
|
||||
|
||||
num_dimensions = 3
|
||||
tensor_size = list(range(2, num_dimensions + 2))
|
||||
total_size = 1
|
||||
for s in tensor_size:
|
||||
total_size *= s
|
||||
|
||||
all_gather_dimension = -1
|
||||
all_tensors = [
|
||||
torch.arange(total_size, dtype=torch.float32,
|
||||
device="xla").reshape(tensor_size) * (r + 1)
|
||||
for r in range(tp_degree)
|
||||
]
|
||||
expected = torch.cat(all_tensors, dim=all_gather_dimension)
|
||||
t = all_tensors[index % tp_degree]
|
||||
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
|
||||
torch.testing.assert_close(t, expected)
|
||||
|
||||
|
||||
def all_reduce_test_worker(index, tp_degree, distributed_init_method):
|
||||
init_distributed_environment(tp_degree,
|
||||
index,
|
||||
distributed_init_method,
|
||||
index,
|
||||
backend="xla")
|
||||
ensure_model_parallel_initialized(tp_degree, 1)
|
||||
|
||||
num_elements = 8
|
||||
all_tensors = [
|
||||
torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1)
|
||||
for r in range(tp_degree)
|
||||
]
|
||||
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||
t = all_tensors[index % tp_degree]
|
||||
t = tensor_model_parallel_all_reduce(t)
|
||||
torch.testing.assert_close(t, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("test_target",
|
||||
[all_reduce_test_worker, all_gather_test_worker])
|
||||
@reinitialize_neuron_runtime
|
||||
def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size,
|
||||
test_target):
|
||||
|
||||
with patch('torch_xla._XLAC._xla_runtime_is_initialized',
|
||||
return_value=False):
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
"127.0.0.1", get_open_port())
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size))
|
||||
monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES",
|
||||
','.join(['1' for _ in range(tp_size)]))
|
||||
|
||||
xmp.spawn(test_target, args=(tp_size, distributed_init_method))
|
@ -1,83 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def patch_eagle_draft_with_lm_head(target_model_id: str,
|
||||
draft_model_id: str) -> str:
|
||||
# In NxDI, draft model checkpoint must include lm_head weights from target
|
||||
# model. For more details see https://awsdocs-neuron.readthedocs-hosted.com
|
||||
# /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html
|
||||
# #eagle-checkpoint-compatibility
|
||||
final_draft_dir = "/tmp/patched_eagle_draft"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
target_dir = snapshot_download(repo_id=target_model_id,
|
||||
local_dir=os.path.join(
|
||||
tmp_dir, "target"))
|
||||
draft_dir = snapshot_download(repo_id=draft_model_id,
|
||||
local_dir=os.path.join(tmp_dir, "draft"))
|
||||
|
||||
lm_head_key = "lm_head.weight"
|
||||
index_path = os.path.join(target_dir, "model.safetensors.index.json")
|
||||
with open(index_path) as f:
|
||||
index = json.load(f)
|
||||
shard_name = index["weight_map"][lm_head_key]
|
||||
target_safetensor_path = os.path.join(target_dir, shard_name)
|
||||
|
||||
with safe_open(target_safetensor_path, framework="pt") as f:
|
||||
target_lm_head = f.get_tensor(lm_head_key)
|
||||
|
||||
draft_path = os.path.join(draft_dir, "pytorch_model.bin")
|
||||
draft_state_dict = torch.load(draft_path, map_location="cpu")
|
||||
draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
|
||||
torch.save(draft_state_dict, draft_path)
|
||||
|
||||
shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)
|
||||
|
||||
return final_draft_dir
|
||||
|
||||
|
||||
def test_eagle():
|
||||
patched_draft_path = patch_eagle_draft_with_lm_head(
|
||||
target_model_id="meta-llama/Llama-2-7b-hf",
|
||||
draft_model_id="yuhuili/EAGLE-llama2-chat-7B")
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
speculative_config={
|
||||
"model": patched_draft_path,
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 128
|
||||
},
|
||||
max_num_seqs=1,
|
||||
max_model_len=128,
|
||||
tensor_parallel_size=2,
|
||||
override_neuron_config={
|
||||
"enable_eagle_speculation": True,
|
||||
"enable_fused_speculation": True,
|
||||
"fused_qkv": True
|
||||
},
|
||||
)
|
||||
prompts = [
|
||||
"The president of the United States is",
|
||||
]
|
||||
outputs = llm.generate(prompts, SamplingParams(top_k=1))
|
||||
expected_output = " the head of state and head of government of " \
|
||||
"the United States. The president direct"
|
||||
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
|
||||
assert (expected_output == generated_text)
|
||||
|
||||
print("Neuron Eagle speculation test passed.")
|
@ -1,64 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def test_mistral():
|
||||
llm = LLM(model="mistralai/Mistral-7B-v0.1",
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=4,
|
||||
max_model_len=128,
|
||||
override_neuron_config={
|
||||
"sequence_parallel_enabled": False,
|
||||
"skip_warmup": True
|
||||
})
|
||||
|
||||
# Send more prompts than the compiled batch size (4) and request
|
||||
# varying generation lengths to test accuracy related to
# Neuron-specific sequence id sorting.
|
||||
prompts = [
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"What is Annapurna labs?",
|
||||
"I believe the meaning of life is",
|
||||
"Tell me a story about a brave knight",
|
||||
"Hello, my name is Llama",
|
||||
]
|
||||
|
||||
sampling_params = [
|
||||
SamplingParams(top_k=1, max_tokens=10),
|
||||
SamplingParams(top_k=1, max_tokens=20),
|
||||
SamplingParams(top_k=1, max_tokens=30),
|
||||
SamplingParams(top_k=1, max_tokens=40),
|
||||
SamplingParams(top_k=1, max_tokens=50),
|
||||
SamplingParams(top_k=1, max_tokens=60)
|
||||
]
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
expected_outputs = [
|
||||
" the most powerful person in the world. He is",
|
||||
" a city of many faces. It is a city of history, culture, art, "
|
||||
"fashion, and",
|
||||
"\n\nAnnapurna Labs is a semiconductor company that was founded "
|
||||
"in 2013 by Amazon. The company is",
|
||||
" to be happy.\n\nI believe that happiness is a choice.\n\nI "
|
||||
"believe that happiness is a state of mind.\n\nI believe that "
|
||||
"happiness is a journey.\n\nI believe",
|
||||
" who rescued a princess from a dragon.\n\nTell me a story about"
|
||||
" a princess who rescued herself from a dragon.\n\nTell me a "
|
||||
"story about a princess who rescued herself from a dragon and "
|
||||
"then rescued a knight from",
|
||||
" and I am a 10 year old male. I am a very friendly and "
|
||||
"affectionate boy who loves to be around people. I am a very "
|
||||
"active boy who loves to play and run around. I am a very smart "
|
||||
"boy who loves to learn new things. I am a very loyal boy"
|
||||
]
|
||||
|
||||
for expected_output, output in zip(expected_outputs, outputs):
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
|
||||
assert (expected_output == generated_text)
|
||||
|
||||
print("Neuron Mistral test passed.")
|
@ -1,97 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def test_llama_single_lora():
|
||||
sql_lora_files = snapshot_download(
|
||||
repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-hf",
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=4,
|
||||
max_model_len=512,
|
||||
override_neuron_config={
|
||||
"sequence_parallel_enabled": False,
|
||||
"skip_warmup": True,
|
||||
"lora_modules": [{
|
||||
"name": "lora_id_1",
|
||||
"path": sql_lora_files
|
||||
}]
|
||||
},
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=256,
|
||||
device="neuron")
|
||||
"""For multi-lora requests using NxDI as the backend, only the lora_name
|
||||
needs to be specified. The lora_id and lora_path are supplied at the LLM
|
||||
class/server initialization, after which the paths are handled by NxDI"""
|
||||
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
|
||||
prompts = [
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
]
|
||||
outputs = llm.generate(prompts,
|
||||
SamplingParams(top_k=1),
|
||||
lora_request=[lora_req_1, lora_req_1])
|
||||
|
||||
expected_outputs = [
|
||||
" the head of state and head of government of the United States. "
|
||||
"The president direct",
|
||||
" a city of contrasts. The city is home to the Eiffel Tower"
|
||||
]
|
||||
|
||||
for expected_output, output in zip(expected_outputs, outputs):
|
||||
generated_text = output.outputs[0].text
|
||||
assert (expected_output == generated_text)
|
||||
|
||||
|
||||
def test_llama_multiple_lora():
|
||||
sql_lora_files = snapshot_download(
|
||||
repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-hf",
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=4,
|
||||
max_model_len=512,
|
||||
override_neuron_config={
|
||||
"sequence_parallel_enabled":
|
||||
False,
|
||||
"skip_warmup":
|
||||
True,
|
||||
"lora_modules": [{
|
||||
"name": "lora_id_1",
|
||||
"path": sql_lora_files
|
||||
}, {
|
||||
"name": "lora_id_2",
|
||||
"path": sql_lora_files
|
||||
}]
|
||||
},
|
||||
enable_lora=True,
|
||||
max_loras=2,
|
||||
max_lora_rank=256,
|
||||
device="neuron")
|
||||
"""For multi-lora requests using NxDI as the backend, only the lora_name
|
||||
needs to be specified. The lora_id and lora_path are supplied at the LLM
|
||||
class/server initialization, after which the paths are handled by NxDI"""
|
||||
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
|
||||
lora_req_2 = LoRARequest("lora_id_2", 1, " ")
|
||||
prompts = [
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
]
|
||||
outputs = llm.generate(prompts,
|
||||
SamplingParams(top_k=1),
|
||||
lora_request=[lora_req_1, lora_req_2])
|
||||
|
||||
expected_outputs = [
|
||||
" the head of state and head of government of the United States. "
|
||||
"The president direct",
|
||||
" a city of contrasts. The city is home to the Eiffel Tower"
|
||||
]
|
||||
|
||||
for expected_output, output in zip(expected_outputs, outputs):
|
||||
generated_text = output.outputs[0].text
|
||||
assert (expected_output == generated_text)
|
@ -1,903 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import neuronxcc.nki.isa as nisa
|
||||
import neuronxcc.nki.language as nl
|
||||
import numpy as np
|
||||
import torch
|
||||
from neuronxcc import nki
|
||||
from neuronxcc.nki.language import par_dim
|
||||
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
def is_power_of_2(x):
|
||||
return x > 0 and (x & (x - 1)) == 0
|
||||
|
||||
|
||||
@nki.jit
|
||||
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
|
||||
"""
|
||||
Load block tables from HBM into SRAM
|
||||
|
||||
`block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
|
||||
In case `num_tiles > B_P_SIZE`, we need to further tile the `num_tiles` dimension.
|
||||
"""
|
||||
B_P_SIZE = 128
|
||||
|
||||
# reshape as `(num_tiles, num_blocks_per_tile)`
|
||||
assert len(block_tables_hbm.shape) == 1
|
||||
(num_total_blocks, ) = block_tables_hbm.shape
|
||||
assert num_blocks_per_tile * num_tiles == num_total_blocks
|
||||
block_tables_hbm = block_tables_hbm.reshape(
|
||||
(num_tiles, num_blocks_per_tile))
|
||||
|
||||
block_tables_sbuf = nl.zeros(
|
||||
(cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
|
||||
dtype=nl.int32,
|
||||
)
|
||||
for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(num_blocks_per_tile)[None, :]
|
||||
block_tables_sbuf[i, i_p, i_f] = nl.load(
|
||||
block_tables_hbm[i_p + i * B_P_SIZE, i_f],
|
||||
dtype=nl.int32,
|
||||
mask=(i_p + i * B_P_SIZE < num_tiles),
|
||||
)
|
||||
return block_tables_sbuf
|
||||
|
||||
|
||||
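As a rough host-side analogue of the tiling performed by `load_block_tables` above, here is a NumPy sketch under assumed sizes (this is plain NumPy for illustration, not NKI code):

import numpy as np

num_tiles, num_blocks_per_tile, B_P_SIZE = 4, 64, 128
block_tables_hbm = np.arange(num_tiles * num_blocks_per_tile, dtype=np.int32)

# View as (num_tiles, num_blocks_per_tile), pad the tile axis up to a
# multiple of B_P_SIZE, then split it into (num_partitions, B_P_SIZE).
bt = block_tables_hbm.reshape(num_tiles, num_blocks_per_tile)
num_partitions = -(-num_tiles // B_P_SIZE)
padded = np.zeros((num_partitions * B_P_SIZE, num_blocks_per_tile), dtype=np.int32)
padded[:num_tiles] = bt
block_tables_sbuf = padded.reshape(num_partitions, B_P_SIZE, num_blocks_per_tile)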
@nki.jit
|
||||
def transform_block_tables_for_indirect_load(
|
||||
block_tables,
|
||||
block_size_tiling_factor,
|
||||
num_head,
|
||||
head_id,
|
||||
):
|
||||
"""
|
||||
This function does two things:
|
||||
1. calculate new `block_tables` for a `head_id` after flattening
|
||||
`num_block`, `num_head`, and `block_size_tiling_factor` dimensions
|
||||
2. transpose the result so that `block_table` for each tile is mapped to
|
||||
SBUF Partition dimension for vectorized DMA
|
||||
|
||||
Tiling trick to further improve DMA performance:
|
||||
Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
|
||||
blocks of a given `head_id` from HBM, the load `cache[block_tables,
|
||||
head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
|
||||
fully utilize hardware parallelization. The solution is to tile `block_size`
|
||||
into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
|
||||
block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
|
||||
`(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.
|
||||
|
||||
Note:
|
||||
We don't further tile the D dimension, as a small DMA size also hurts performance.
|
||||
"""
|
||||
B_P_SIZE = 128
|
||||
num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
|
||||
block_tables.shape)
|
||||
assert num_tiles_per_partition == B_P_SIZE
|
||||
assert is_power_of_2(
|
||||
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
|
||||
|
||||
num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
|
||||
block_tables_transposed = nl.ndarray(
|
||||
(
|
||||
num_loads,
|
||||
par_dim(B_P_SIZE),
|
||||
num_partitions * num_tiles_per_partition,
|
||||
),
|
||||
dtype=nl.int32,
|
||||
)
|
||||
|
||||
# prepare iota ahead of time to avoid repeatedly using Gpsimd
|
||||
if num_head > 1:
|
||||
head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
|
||||
head_id = nl.transpose(
|
||||
head_id.broadcast_to((1, num_tiles_per_partition)))
|
||||
if num_blocks_per_tile > 1:
|
||||
head_id = head_id.broadcast_to(
|
||||
(num_tiles_per_partition, num_blocks_per_tile))
|
||||
|
||||
if block_size_tiling_factor > 1:
|
||||
broadcast_shape = (
|
||||
num_tiles_per_partition,
|
||||
num_blocks_per_tile,
|
||||
block_size_tiling_factor,
|
||||
)
|
||||
offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
|
||||
dtype=nl.int32).broadcast_to(broadcast_shape)
|
||||
|
||||
for partition_id in nl.affine_range(num_partitions):
|
||||
block_tables_partition = block_tables[partition_id]
|
||||
if num_head > 1:
|
||||
# fuse num_block and num_head dimension
|
||||
block_tables_partition = block_tables_partition * num_head + head_id
|
||||
|
||||
if block_size_tiling_factor > 1:
|
||||
# need to apply block size tiling trick
|
||||
assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
|
||||
block_tables_partition = ((block_tables_partition *
|
||||
block_size_tiling_factor).reshape(
|
||||
(num_tiles_per_partition,
|
||||
num_blocks_per_tile,
|
||||
1)).broadcast_to(broadcast_shape))
|
||||
new_block_tables = block_tables_partition + offset
|
||||
new_block_tables = new_block_tables.reshape(
|
||||
(num_tiles_per_partition, B_P_SIZE))
|
||||
else:
|
||||
new_block_tables = block_tables_partition
|
||||
|
||||
# transpose the block table so that it can be used by vector DGE
|
||||
for i in nl.affine_range(num_loads):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = (partition_id * num_tiles_per_partition +
|
||||
nl.arange(num_tiles_per_partition)[None, :])
|
||||
block_tables_transposed[i, i_p, i_f] = nl.transpose(
|
||||
new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
|
||||
return block_tables_transposed
|
||||
|
||||
|
||||
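A small numeric sketch of the block-size tiling trick described in the docstring above; the sizes are assumed for illustration:

B_P_SIZE = 128
num_blocks_per_tile = 16   # "M" in the docstring, hypothetical
block_size = 32

# Pick a tiling factor so that M * factor == B_P_SIZE, then split each
# block of block_size tokens into (factor, tiled_block_size).
block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile  # 8
tiled_block_size = block_size // block_size_tiling_factor   # 4
assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE

# The KV cache is then viewed as
# (num_block, num_head, block_size_tiling_factor, tiled_block_size, D),
# so one tile load spans 128 partitions instead of 16.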
@nki.jit
|
||||
def load_kv_tile_from_cache(
|
||||
cur_k_tile,
|
||||
cur_v_tile,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
large_k_tile_idx,
|
||||
num_blocks_per_large_tile,
|
||||
tiled_block_size,
|
||||
B_P_SIZE,
|
||||
B_D_SIZE,
|
||||
):
|
||||
"""
|
||||
Load KV cache and transform Key and Value into layout required by Matmul
|
||||
|
||||
Vectorized DMA Load layout:
|
||||
Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
|
||||
|
||||
Layout used by attention matmuls:
|
||||
Key: (par_dim(B_D_SIZE), seqlen_kv)
|
||||
Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
|
||||
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
|
||||
"""
|
||||
# load key cache
|
||||
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
|
||||
for load_idx in nl.affine_range(num_loads):
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
|
||||
loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
|
||||
large_k_tile_idx], i_f])
|
||||
if cur_k_tile.dtype != loaded.dtype:
|
||||
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
|
||||
# Transpose SBUF tensor using PE
|
||||
for tb_i in nl.affine_range(tiled_block_size):
|
||||
cur_k_tile[
|
||||
:,
|
||||
nl.ds(
|
||||
load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
|
||||
B_P_SIZE,
|
||||
),
|
||||
] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])
|
||||
|
||||
# load value cache
|
||||
for load_idx in nl.affine_range(num_loads):
|
||||
loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
|
||||
large_k_tile_idx], i_f])
|
||||
if cur_v_tile.dtype != loaded.dtype:
|
||||
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
|
||||
i_p = nl.arange(B_P_SIZE)[:, None]
|
||||
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
|
||||
cur_v_tile[
|
||||
:,
|
||||
nl.ds(
|
||||
load_idx * tiled_block_size * B_D_SIZE,
|
||||
tiled_block_size * B_D_SIZE,
|
||||
),
|
||||
] = loaded
|
||||
|
||||
|
||||
@nki.jit
|
||||
def transpose_p_local(p_local_transposed,
|
||||
p_local,
|
||||
LARGE_TILE_SZ,
|
||||
B_F_SIZE=512):
|
||||
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
|
||||
buffer=nl.sbuf,
|
||||
dtype=p_local.dtype)
|
||||
else:
|
||||
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
|
||||
buffer=nl.psum,
|
||||
dtype=np.float32)
|
||||
|
||||
for j in nl.affine_range(B_F_SIZE // 128):
|
||||
j_128_slice = nl.ds(j * 128, 128)
|
||||
i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)
|
||||
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
|
||||
p_local[:, i_j_128_slice])
|
||||
else:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
|
||||
p_local[:, i_j_128_slice])
|
||||
|
||||
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
|
||||
p_local_t_tmp, dtype=p_local_transposed.dtype)
|
||||
|
||||
|
||||
@nki.jit
|
||||
def _flash_attention_core(
|
||||
q_local_tile,
|
||||
k,
|
||||
v,
|
||||
o_buffer,
|
||||
l_buffer,
|
||||
m_buffer,
|
||||
kernel_dtype,
|
||||
acc_type,
|
||||
tile_mask,
|
||||
use_causal_mask,
|
||||
q_tile_idx=None,
|
||||
initialize=False,
|
||||
LARGE_TILE_SZ=2048,
|
||||
B_P_SIZE=128,
|
||||
B_F_SIZE=512,
|
||||
B_D_SIZE=128,
|
||||
qk_res_buffer=None,
|
||||
):
|
||||
"""
|
||||
The flash attention core function to calculate self attention between a tile
|
||||
of q and a block of K and V.
|
||||
The q_local_tile has (B_P_SIZE, B_D_SIZE)
|
||||
The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
|
||||
be split into size B_F_SIZE tiles
|
||||
|
||||
The results are stored in the following three buffers
|
||||
o_buffer: (B_P_SIZE, d)
|
||||
l_buffer: (B_P_SIZE, 1)
|
||||
m_buffer: (B_P_SIZE, 1)
|
||||
|
||||
All IO buffers are in SBUF.
|
||||
"""
|
||||
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
|
||||
|
||||
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
buffer=nl.sbuf,
|
||||
dtype=acc_type)
|
||||
max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
|
||||
dtype=acc_type)
|
||||
for k_i in nl.affine_range(num_k_tile_per_large_tile):
|
||||
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
|
||||
|
||||
if use_causal_mask:
|
||||
# the mask is used to restrict computation to the lower half of the
# matrix, which reduces the arithmetic intensity by up to 50%
|
||||
multiplication_required_selection = (q_tile_idx * B_P_SIZE
|
||||
>= k_i * B_F_SIZE)
|
||||
else:
|
||||
multiplication_required_selection = True
|
||||
|
||||
if multiplication_required_selection:
|
||||
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum) # (128, 512)
|
||||
qk_psum[:, :] = nl.matmul(q_local_tile,
|
||||
k[:, k_i_b_f_slice],
|
||||
transpose_x=True) # (p(128), 512)
|
||||
qk_res_buf[:, k_i_b_f_slice] = nl.where(
|
||||
tile_mask[:, k_i_b_f_slice],
|
||||
qk_psum[:, nl.ds(0, B_F_SIZE)],
|
||||
-9984.0,
|
||||
dtype=acc_type,
|
||||
)
|
||||
else:
|
||||
qk_res_buf[:, k_i_b_f_slice] = -9984.0
|
||||
|
||||
# Calculate max of the current tile
|
||||
max_local[:, k_i] = nisa.tensor_reduce(
|
||||
np.max,
|
||||
qk_res_buf[:, k_i_b_f_slice],
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
)
|
||||
|
||||
if qk_res_buffer is not None:
|
||||
qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])
|
||||
|
||||
max_ = nisa.tensor_reduce(
|
||||
np.max,
|
||||
max_local[:, :],
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
)
|
||||
|
||||
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
|
||||
dtype=o_buffer.dtype)
|
||||
|
||||
if initialize:
|
||||
m_buffer[:, 0] = nl.copy(max_)
|
||||
m_current = max_
|
||||
else:
|
||||
m_previous = nl.copy(m_buffer[:, 0])
|
||||
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
|
||||
|
||||
m_current = m_buffer[:, 0]
|
||||
# Compute scaling factor
|
||||
alpha = nisa.activation(
|
||||
np.exp,
|
||||
m_previous,
|
||||
bias=-1 * m_current,
|
||||
scale=1.0,
|
||||
)
|
||||
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
|
||||
|
||||
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
|
||||
|
||||
p_partial_sum = nl.ndarray(
|
||||
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
|
||||
dtype=acc_type,
|
||||
)
|
||||
|
||||
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
|
||||
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
|
||||
|
||||
# compute exp(qk - max)
|
||||
# Compute partial row - tile sum of exp(qk - max))
|
||||
# FIXME : Use activation accumulate to accumulate over k_r_i loop ?
|
||||
p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
|
||||
np.exp,
|
||||
qk_res_buf[:, k_r_i_reduce_slice],
|
||||
bias=-1 * m_current,
|
||||
scale=1.0,
|
||||
reduce_op=nl.add,
|
||||
reduce_res=p_partial_sum[:, k_r_i],
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
|
||||
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
|
||||
|
||||
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
transpose_p_local(
|
||||
p_local_transposed=p_local_transposed,
|
||||
p_local=p_local,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
)
|
||||
|
||||
pv_psum = nl.zeros(
|
||||
(par_dim(B_P_SIZE), B_D_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum,
|
||||
)
|
||||
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
|
||||
pv_psum[:, :] += nl.matmul(
|
||||
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
|
||||
v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
|
||||
transpose_x=True,
|
||||
) # (128, 128) (p(Br), d)
|
||||
|
||||
if initialize:
|
||||
o_buffer[:, :] = nl.copy(pv_psum[:, :])
|
||||
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
|
||||
else:
|
||||
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
|
||||
|
||||
l_prev = l_buffer[:, 0]
|
||||
l_exp = nl.add(
|
||||
nl.exp(nl.subtract(l_prev, m_current)),
|
||||
ps,
|
||||
)
|
||||
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
|
||||
|
||||
|
||||
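For reference, the running statistics kept in `o_buffer`, `l_buffer`, and `m_buffer` follow the standard online-softmax recurrence. The following host-side NumPy sketch of one accumulation step is an illustration under that assumption, not the kernel itself:

import numpy as np

def online_softmax_step(o, l, m, scores, v_tile, initialize):
    # scores: (B_P, tile), v_tile: (tile, d); o is the unnormalized output,
    # l stores m + log(sum_exp), matching how l_buffer is maintained above.
    tile_max = scores.max(axis=-1, keepdims=True)
    if initialize:
        m_new = tile_max
        p = np.exp(scores - m_new)
        return p @ v_tile, m_new + np.log(p.sum(-1, keepdims=True)), m_new
    m_new = np.maximum(m, tile_max)
    p = np.exp(scores - m_new)
    o_new = o * np.exp(m - m_new) + p @ v_tile
    l_new = m_new + np.log(np.exp(l - m_new) + p.sum(-1, keepdims=True))
    return o_new, l_new, m_new

# Final output, as in the HBM write-out further below: out = o * np.exp(m - l)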
@nki.jit
|
||||
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
|
||||
B_P_SIZE = 128
|
||||
B_D_SIZE = v_hbm_tile.shape[-1]
|
||||
loaded = nl.load(v_hbm_tile[
|
||||
nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
|
||||
:,
|
||||
])
|
||||
if cur_v_tile.dtype != loaded.dtype:
|
||||
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
|
||||
cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
|
||||
|
||||
|
||||
@nki.jit
|
||||
def flash_paged_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
mask,
|
||||
softmax_scale=None,
|
||||
mixed_precision=True,
|
||||
LARGE_TILE_SZ=2048,
|
||||
return_debug_tensors=False,
|
||||
):
|
||||
"""
|
||||
Flash PagedAttention Forward Kernel.
|
||||
|
||||
IO tensor layouts:
|
||||
- query: shape (1, n_heads, d, seq_q)
|
||||
- key: shape (1, n_kv_heads, d, seq_k)
|
||||
- value: shape (1, n_kv_heads, seq_v, d)
|
||||
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
|
||||
- block_tables: (num_active_blocks, )
|
||||
- mask: (seq_q, num_active_blocks * block_size + seq_q)
|
||||
- o: shape (1, n_heads, seq_q, d)
|
||||
|
||||
- This kernel requires seq_k == seq_v
|
||||
- We use continuous batching by default, so the batch dimension is
|
||||
always 1, and different requests are concatenated along sequence
|
||||
dimension.
|
||||
- We use paged cache blocks (kv_cache) to store KV cache.
|
||||
|
||||
IO tensor dtypes:
|
||||
- This kernel assumes all IO tensors have the same dtype except for
|
||||
block_tables (int32) and mask (int32)
|
||||
- If mixed_precision is True, then all Tensor Engine operation will be
|
||||
performed in bfloat16 and accumulation will be performed in float32.
|
||||
Otherwise the intermediates will be in the same type as the inputs.
|
||||
|
||||
Compile-time Constants:
|
||||
- softmax_scale: scaling for softmax; if None, defaults to `1.0/(d**0.5)`
|
||||
- mixed_precision: flag to run non-matmul ops in fp32 precision; defaults
to `True`. If `False`, the same precision as the input types is used.
|
||||
- LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
|
||||
computation reduction
|
||||
|
||||
GQA support notes:
the SPMD grid for launching the kernel should be over kv_heads instead of
nheads
|
||||
|
||||
Example usage:
|
||||
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
|
||||
usage: `flash_fwd[b, h](q, k, v, ...)`
|
||||
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
|
||||
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
|
||||
"""
|
||||
B_F_SIZE = 512
|
||||
B_P_SIZE = 128
|
||||
b, h, d, seqlen_q = query.shape
|
||||
B_D_SIZE = d
|
||||
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
|
||||
_, num_blocks, k_h, block_size, _ = kv_cache.shape
|
||||
q_h_per_k_h = h // k_h
|
||||
assert b == 1, f"invalid batch size {b=}"
|
||||
assert d <= 128, f"we do not support head_dim > 128, got head dim {d=}"
|
||||
cache_shape = (2, num_blocks, k_h, block_size, d)
|
||||
assert (tuple(kv_cache.shape) == cache_shape
|
||||
), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
|
||||
assert key is None or tuple(key.shape) == (
|
||||
1,
|
||||
k_h,
|
||||
d,
|
||||
seqlen_q,
|
||||
), f"key shape {key.shape} mismatch!"
|
||||
assert value is None or tuple(value.shape) == (
|
||||
1,
|
||||
k_h,
|
||||
seqlen_q,
|
||||
d,
|
||||
), f"value shape {value.shape} mismatch!"
|
||||
|
||||
assert (
|
||||
nl.program_ndim() == 2
|
||||
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
|
||||
batch_id = nl.program_id(axis=0)
|
||||
head_id = nl.program_id(axis=1)
|
||||
|
||||
(num_active_blocks, ) = block_tables.shape
|
||||
context_kv_len = num_active_blocks * block_size
|
||||
assert (
|
||||
LARGE_TILE_SZ % B_F_SIZE == 0
|
||||
), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
|
||||
assert (context_kv_len % LARGE_TILE_SZ == 0
|
||||
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
|
||||
|
||||
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
|
||||
assert is_power_of_2(
|
||||
num_blocks_per_large_tile
|
||||
), f"{num_blocks_per_large_tile=} is expected of be power of 2"
|
||||
if seqlen_q > B_F_SIZE:
|
||||
MAX_REDUCTION_TILE = 2048
|
||||
if seqlen_q // 2 > MAX_REDUCTION_TILE:
|
||||
assert (
|
||||
seqlen_q % MAX_REDUCTION_TILE == 0
|
||||
), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
|
||||
else:
|
||||
assert (seqlen_q % B_F_SIZE == 0
|
||||
), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"
|
||||
|
||||
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
|
||||
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
|
||||
softmax_scale = softmax_scale or (1.0 / (d**0.5))
|
||||
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
|
||||
|
||||
o = nl.ndarray((b, h, seqlen_q, d),
|
||||
dtype=query.dtype,
|
||||
buffer=nl.shared_hbm)
|
||||
hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
if return_debug_tensors:
|
||||
hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.shared_hbm)
|
||||
hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.shared_hbm)
|
||||
hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.shared_hbm)
|
||||
qk_res_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
block_tables_sbuf = load_block_tables(
|
||||
block_tables_hbm=block_tables,
|
||||
num_tiles=num_large_k_tile,
|
||||
num_blocks_per_tile=num_blocks_per_large_tile,
|
||||
)
|
||||
|
||||
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
|
||||
if num_blocks_per_large_tile < B_P_SIZE:
|
||||
# we checked num_blocks_per_tile is a power of 2
|
||||
assert B_P_SIZE % num_blocks_per_large_tile == 0
|
||||
block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
|
||||
# We assume block_size >= block_size_tiling_factor
|
||||
assert block_size % block_size_tiling_factor == 0
|
||||
else:
|
||||
block_size_tiling_factor = 1
|
||||
tiled_block_size = block_size // block_size_tiling_factor
|
||||
|
||||
# Indirect DMA load must be placed along Partition Dimension
|
||||
block_tables_sbuf = transform_block_tables_for_indirect_load(
|
||||
block_tables_sbuf,
|
||||
block_size_tiling_factor=block_size_tiling_factor,
|
||||
num_head=k_h,
|
||||
head_id=head_id,
|
||||
)
|
||||
|
||||
# Flatten KV cache to be 3D for loading into SBUF
|
||||
new_cache_shape = (
|
||||
2,
|
||||
num_blocks * k_h * block_size_tiling_factor,
|
||||
tiled_block_size * d,
|
||||
)
|
||||
kv_cache = kv_cache.reshape(new_cache_shape)
|
||||
|
||||
# Global Flash Attention accumulators
|
||||
o_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
l_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
m_buffer = nl.zeros(
|
||||
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
|
||||
dtype=acc_type,
|
||||
buffer=nl.sbuf,
|
||||
lazy_initialization=True,
|
||||
)
|
||||
|
||||
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
|
||||
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
|
||||
cur_k_tile = nl.ndarray(
|
||||
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
cur_v_tile = nl.ndarray(
|
||||
(par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
load_kv_tile_from_cache(
|
||||
cur_k_tile=cur_k_tile,
|
||||
cur_v_tile=cur_v_tile,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables_sbuf,
|
||||
large_k_tile_idx=large_k_tile_idx,
|
||||
num_blocks_per_large_tile=num_blocks_per_large_tile,
|
||||
tiled_block_size=tiled_block_size,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
)
|
||||
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.load(mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
|
||||
])
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(q_hbm_tile[:,
|
||||
nl.ds(i *
|
||||
B_P_SIZE, B_P_SIZE)])
|
||||
if q_sbuf_tile.dtype != kernel_dtype:
|
||||
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
|
||||
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
||||
|
||||
_flash_attention_core(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
tile_mask=cur_mask,
|
||||
use_causal_mask=False,
|
||||
q_tile_idx=i,
|
||||
initialize=large_k_tile_idx == 0,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
)
|
||||
|
||||
# compute attention between input query, key and value
|
||||
if key is not None and value is not None:
|
||||
B_F_SIZE = min(seqlen_q, B_F_SIZE)
|
||||
LARGE_TILE_SZ = seqlen_q
|
||||
|
||||
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
cur_v_tile = nl.ndarray(
|
||||
(par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
|
||||
loaded = nl.load(key[batch_id, head_id, :, :])
|
||||
if loaded.dtype != kernel_dtype:
|
||||
loaded = nl.copy(loaded, dtype=kernel_dtype)
|
||||
cur_k_tile[:, :] = loaded
|
||||
|
||||
v_hbm_tile = value[batch_id, head_id]
|
||||
for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
|
||||
load_v_tile(
|
||||
v_hbm_tile=v_hbm_tile,
|
||||
cur_v_tile=cur_v_tile,
|
||||
large_tile_idx=0,
|
||||
v_i=v_i,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
)
|
||||
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.load(mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(context_kv_len, LARGE_TILE_SZ),
|
||||
])
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(q_hbm_tile[:,
|
||||
nl.ds(i *
|
||||
B_P_SIZE, B_P_SIZE)])
|
||||
if q_sbuf_tile.dtype != kernel_dtype:
|
||||
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
|
||||
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
||||
_flash_attention_core(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
tile_mask=cur_mask,
|
||||
use_causal_mask=True,
|
||||
q_tile_idx=i,
|
||||
initialize=False,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
qk_res_buffer=(qk_res_buffer[i, i_q_h]
|
||||
if qk_res_buffer is not None else None),
|
||||
)
|
||||
|
||||
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
out = nl.multiply(
|
||||
o_buffer[i, i_q_h],
|
||||
nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
|
||||
dtype=kernel_dtype,
|
||||
)
|
||||
|
||||
nl.store(
|
||||
o[
|
||||
batch_id,
|
||||
head_id * q_h_per_k_h + i_q_h,
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
:,
|
||||
],
|
||||
out,
|
||||
)
|
||||
# maximum and summation statistics
|
||||
if return_debug_tensors:
|
||||
nl.store(
|
||||
hbm_m_buffer[
|
||||
batch_id,
|
||||
head_id * q_h_per_k_h + i_q_h,
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
],
|
||||
m_buffer[i, i_q_h, :, :],
|
||||
)
|
||||
nl.store(
|
||||
hbm_l_buffer[
|
||||
batch_id,
|
||||
head_id * q_h_per_k_h + i_q_h,
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
],
|
||||
l_buffer[i, i_q_h],
|
||||
)
|
||||
nl.store(
|
||||
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
|
||||
qk_res_buffer[batch_id, i_q_h, :, :],
|
||||
)
|
||||
|
||||
if return_debug_tensors:
|
||||
return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
|
||||
return o
|
||||
|
||||
|
||||
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
|
||||
"""
|
||||
Reorder the mask to make it compatible with the flash attention kernel.
|
||||
|
||||
We vectorize KV cache read to improve DMA utilization. However, the layout
|
||||
that maximizes DMA bandwidth changes the order in which tokens are consumed.
|
||||
|
||||
The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE,
|
||||
tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. At
each step the engine consumes a column (rather than a row) of B_P_SIZE
|
||||
tokens. Therefore, the tokens are visited in a strided way.
|
||||
|
||||
To make sure the mask matches the order in which tokens are consumed, we
need to properly transpose the mask.
|
||||
"""
|
||||
total_query_len, total_seq_len = mask.shape
|
||||
context_kv_len = total_seq_len - total_query_len
|
||||
|
||||
B_P_SIZE = 128
|
||||
assert (LARGE_TILE_SZ
|
||||
>= B_P_SIZE), f"{LARGE_TILE_SZ=} must be at least {B_P_SIZE=}"
|
||||
num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
|
||||
tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
|
||||
if tiled_block_size > 1:
|
||||
# Mask reordering is needed when tiled_block_size > 1
|
||||
device = mask.device
|
||||
mask = mask.cpu()
|
||||
context_mask = mask[:, :context_kv_len]
|
||||
context_mask = context_mask.view(
|
||||
total_query_len,
|
||||
context_kv_len // LARGE_TILE_SZ,
|
||||
num_tiled_blocks // B_P_SIZE,
|
||||
B_P_SIZE,
|
||||
tiled_block_size,
|
||||
)
|
||||
context_mask = context_mask.transpose(3, 4).reshape(
|
||||
total_query_len, context_kv_len)
|
||||
new_mask = mask[:, context_kv_len:]
|
||||
return torch.concat([context_mask, new_mask], dim=1).to(device)
|
||||
else:
|
||||
return mask
|
||||
|
||||
|
||||
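A minimal usage sketch of the reordering above on toy shapes; the values are assumed, and `reorder_context_mask` refers to the function defined here:

import torch

total_query_len, block_size, LARGE_TILE_SZ = 4, 2, 256
context_kv_len = LARGE_TILE_SZ  # a single large tile of prior context
mask = torch.ones(total_query_len,
                  context_kv_len + total_query_len,
                  dtype=torch.bool)
reordered = reorder_context_mask(mask, LARGE_TILE_SZ, block_size)
assert reordered.shape == mask.shape  # reordering only permutes context columns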
def flash_attn_varlen_nkifunc(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
block_table,
|
||||
attn_mask,
|
||||
n_kv_head=None,
|
||||
head_size=None,
|
||||
LARGE_TILE_SZ=2048,
|
||||
mixed_precision=True,
|
||||
):
|
||||
"""
|
||||
Compute flash paged attention for variable length sequences.
|
||||
|
||||
This function is a wrapper around the flash attention NKI kernel. It takes
|
||||
in the following arguments:
|
||||
- query: (1, n_heads, d, seq_q)
|
||||
- key: (1, n_kv_heads, d, seq_k)
|
||||
- value: (1, n_kv_heads, seq_v, d)
|
||||
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
|
||||
- block_tables: (n_active_blocks, )
|
||||
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
|
||||
|
||||
Notes:
|
||||
- attn_mask must be reordered outside using `reorder_context_mask`
|
||||
- Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
|
||||
for better DMA throughput
|
||||
"""
|
||||
if n_kv_head is None:
|
||||
n_kv_head = kv_cache.shape[2]
|
||||
assert kv_cache.shape[0] == 2
|
||||
assert kv_cache.shape[2] == n_kv_head
|
||||
if head_size is None:
|
||||
head_size = kv_cache.shape[-1]
|
||||
|
||||
kwargs = dict(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_table,
|
||||
mask=attn_mask,
|
||||
softmax_scale=1.0 / (head_size**0.5),
|
||||
mixed_precision=mixed_precision,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
)
|
||||
|
||||
o = flash_paged_attention[1, n_kv_head](**kwargs)
|
||||
return o
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Writes key-value pairs to the KV cache at specified positions.
|
||||
|
||||
Args:
|
||||
key (torch.Tensor): Key tensor with shape
|
||||
(num_tokens, n_kv_head, d_head)
|
||||
value (torch.Tensor): Value tensor with shape
|
||||
(num_tokens, n_kv_head, d_head)
|
||||
kv_cache (torch.Tensor): Key/value cache tensor with shape
|
||||
(2, num_blocks, n_kv_head, block_size, d_head)
|
||||
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
|
||||
with shape (num_tokens)
|
||||
|
||||
Returns:
|
||||
None: Updates the kv_cache tensor in-place
|
||||
"""
|
||||
block_size = kv_cache.size(3)
|
||||
n_kv_head = key.size(1)
|
||||
|
||||
# Calculate indices with explicit floor division
|
||||
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
|
||||
block_offsets = slot_mapping % block_size
|
||||
|
||||
# Create the head indices tensor
|
||||
head_indices = torch.arange(n_kv_head, device=key.device)
|
||||
|
||||
# Update caches using index_put_
|
||||
kv_cache.index_put_(
|
||||
(torch.tensor([0], device=key.device), block_indices[:, None],
|
||||
head_indices[None, :], block_offsets[:, None]), key)
|
||||
|
||||
kv_cache.index_put_(
|
||||
(torch.tensor([1], device=key.device), block_indices[:, None],
|
||||
head_indices[None, :], block_offsets[:, None]), value)
|
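A quick host-side usage sketch of `reshape_and_cache` with toy shapes; all sizes and the slot mapping are assumed:

import torch

num_tokens, n_kv_head, d_head = 3, 2, 8
num_blocks, block_size = 4, 2
key = torch.randn(num_tokens, n_kv_head, d_head)
value = torch.randn(num_tokens, n_kv_head, d_head)
kv_cache = torch.zeros(2, num_blocks, n_kv_head, block_size, d_head)
slot_mapping = torch.tensor([0, 1, 5])  # token i -> block i // block_size, offset i % block_size

reshape_and_cache(key, value, kv_cache, slot_mapping)
assert torch.equal(kv_cache[0, 2, :, 1], key[2])    # slot 5 -> block 2, offset 1
assert torch.equal(kv_cache[1, 0, :, 0], value[0])  # slot 0 -> block 0, offset 0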
@ -54,7 +54,6 @@ SystemEnv = namedtuple(
|
||||
'is_xnnpack_available',
|
||||
'cpu_info',
|
||||
'rocm_version', # vllm specific field
|
||||
'neuron_sdk_version', # vllm specific field
|
||||
'vllm_version', # vllm specific field
|
||||
'vllm_build_flags', # vllm specific field
|
||||
'gpu_topo', # vllm specific field
|
||||
@ -275,15 +274,6 @@ def get_rocm_version(run_lambda):
|
||||
r'HIP version: (\S+)')
|
||||
|
||||
|
||||
def get_neuron_sdk_version(run_lambda):
|
||||
# Adapted from your install script
|
||||
try:
|
||||
result = run_lambda(["neuron-ls"])
|
||||
return result if result[0] == 0 else 'N/A'
|
||||
except Exception:
|
||||
return 'N/A'
|
||||
|
||||
|
||||
def get_vllm_version():
|
||||
from vllm import __version__, __version_tuple__
|
||||
|
||||
@ -306,10 +296,9 @@ def get_vllm_version():
|
||||
|
||||
def summarize_vllm_build_flags():
|
||||
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
|
||||
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format(
|
||||
return 'CUDA Archs: {}; ROCm: {}'.format(
|
||||
os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'),
|
||||
'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled',
|
||||
'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled',
|
||||
)
|
||||
|
||||
|
||||
@ -601,7 +590,6 @@ def get_env_info():
|
||||
conda_packages = get_conda_packages(run_lambda)
|
||||
|
||||
rocm_version = get_rocm_version(run_lambda)
|
||||
neuron_sdk_version = get_neuron_sdk_version(run_lambda)
|
||||
vllm_version = get_vllm_version()
|
||||
vllm_build_flags = summarize_vllm_build_flags()
|
||||
gpu_topo = get_gpu_topo(run_lambda)
|
||||
@ -635,7 +623,6 @@ def get_env_info():
|
||||
is_xnnpack_available=is_xnnpack_available(),
|
||||
cpu_info=get_cpu_info(run_lambda),
|
||||
rocm_version=rocm_version,
|
||||
neuron_sdk_version=neuron_sdk_version,
|
||||
vllm_version=vllm_version,
|
||||
vllm_build_flags=vllm_build_flags,
|
||||
gpu_topo=gpu_topo,
|
||||
@ -702,7 +689,6 @@ env_info_fmt += """
|
||||
vLLM Info
|
||||
==============================
|
||||
ROCM Version : {rocm_version}
|
||||
Neuron SDK Version : {neuron_sdk_version}
|
||||
vLLM Version : {vllm_version}
|
||||
vLLM Build Flags:
|
||||
{vllm_build_flags}
|
||||
|
@ -461,11 +461,6 @@ class ModelConfig:
|
||||
DP (which is controlled by `--data-parallel-size`).
|
||||
This is only supported on a per-model basis and falls back to
|
||||
`"weights"` if the encoder does not support DP."""
|
||||
override_neuron_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""Initialize non-default neuron config or override default neuron config
|
||||
that are specific to Neuron devices, this argument will be used to
|
||||
configure the neuron config that can not be gathered from the vllm
|
||||
arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`."""
|
||||
pooler_config: Optional["PoolerConfig"] = field(init=False)
|
||||
"""Pooler config which controls the behaviour of output pooling in pooling
|
||||
models."""
|
||||
@ -785,10 +780,6 @@ class ModelConfig:
|
||||
if not self.skip_tokenizer_init:
|
||||
self._verify_tokenizer_mode()
|
||||
|
||||
if (not current_platform.is_neuron() and self.override_neuron_config):
|
||||
raise ValueError(
|
||||
"`override_neuron_config` is only supported on Neuron.")
|
||||
|
||||
# Avoid running try_verify_and_update_config multiple times
|
||||
self.config_updated = False
|
||||
|
||||
@ -1696,13 +1687,7 @@ class ModelConfig:
|
||||
"""
|
||||
For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
|
||||
True to enable cross-attention
|
||||
Neuron needs all multimodal data to be in the decoder and does not
|
||||
need to explicitly enable cross-attention
|
||||
"""
|
||||
if (current_platform.is_neuron()
|
||||
and self.hf_config.model_type == "mllama"):
|
||||
return False
|
||||
|
||||
return is_encoder_decoder(self.hf_config)
|
||||
|
||||
@property
|
||||
@ -1871,7 +1856,7 @@ class LoadConfig:
|
||||
self.ignore_patterns = ["original/**/*"]
|
||||
|
||||
|
||||
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"]
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
|
||||
@config
|
||||
@ -1927,9 +1912,7 @@ class DeviceConfig:
|
||||
self.device_type = self.device.type
|
||||
|
||||
# Some device types require processing inputs on CPU
|
||||
if self.device_type in ["neuron"]:
|
||||
self.device = torch.device("cpu")
|
||||
elif self.device_type in ["tpu"]:
|
||||
if self.device_type in ["tpu"]:
|
||||
self.device = None
|
||||
else:
|
||||
# Set device with device type
|
||||
@ -3941,7 +3924,6 @@ class VllmConfig:
|
||||
f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
|
||||
f"tokenizer_mode={self.model_config.tokenizer_mode}, "
|
||||
f"revision={self.model_config.revision}, "
|
||||
f"override_neuron_config={self.model_config.override_neuron_config}, " # noqa
|
||||
f"tokenizer_revision={self.model_config.tokenizer_revision}, "
|
||||
f"trust_remote_code={self.model_config.trust_remote_code}, "
|
||||
f"dtype={self.model_config.dtype}, "
|
||||
|
@ -33,9 +33,8 @@ class CacheConfig:
|
||||
"""Configuration for the KV cache."""
|
||||
|
||||
block_size: SkipValidation[BlockSize] = None # type: ignore
|
||||
"""Size of a contiguous cache block in number of tokens. This is ignored on
|
||||
neuron devices and set to `--max-model-len`. On CUDA devices, only block
|
||||
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
|
||||
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
|
||||
only block sizes up to 32 are supported.
|
||||
|
||||
This config has no static default. If left unspecified by the user, it will
|
||||
be set in `Platform.check_and_update_config()` based on the current
|
||||
|
@ -377,10 +377,7 @@ class ParallelConfig:
|
||||
from vllm.executor import ray_utils
|
||||
backend: DistributedExecutorBackend = "mp"
|
||||
ray_found = ray_utils.ray_is_available()
|
||||
if current_platform.is_neuron():
|
||||
# neuron uses single process to control multiple devices
|
||||
backend = "uni"
|
||||
elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
|
||||
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
|
||||
backend = "uni"
|
||||
elif (current_platform.is_cuda()
|
||||
and cuda_device_count_stateless() < self.world_size):
|
||||
|
@ -1,20 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.platforms import current_platform

if current_platform.is_neuron():
import torch_xla.core.xla_model as xm


class NeuronCommunicator(DeviceCommunicatorBase):

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)

def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "Neuron only supports dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)
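For reference, the deleted communicator's `all_gather` only accepted `dim=-1`, i.e. concatenation of per-rank shards along the last axis. The toy snippet below illustrates that semantics with plain `torch.cat` on two in-process tensors; it uses no `torch_xla` and no collective runtime.

import torch

# Two hypothetical per-rank shards of shape (2, 3); an all-gather along
# dim=-1 hands every rank the concatenation of shape (2, 6).
shard_rank0 = torch.zeros(2, 3)
shard_rank1 = torch.ones(2, 3)
gathered = torch.cat([shard_rank0, shard_rank1], dim=-1)
assert gathered.shape == (2, 6)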
@ -419,8 +419,6 @@ class EngineArgs:
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls

override_neuron_config: dict[str, Any] = \
get_field(ModelConfig, "override_neuron_config")
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
ModelConfig.override_pooler_config
compilation_config: CompilationConfig = \
@ -561,8 +559,6 @@ class EngineArgs:
help=model_kwargs["hf_token"]["help"])
model_group.add_argument("--hf-overrides",
**model_kwargs["hf_overrides"])
model_group.add_argument("--override-neuron-config",
**model_kwargs["override_neuron_config"])
model_group.add_argument("--override-pooler-config",
**model_kwargs["override_pooler_config"])
model_group.add_argument("--logits-processor-pattern",
@ -992,7 +988,6 @@ class EngineArgs:
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config,
@ -236,7 +236,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================

# Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu]
# rocm, cpu]
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(),
@ -73,11 +73,6 @@ class CustomOp(nn.Module):
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)

def forward_neuron(self, *args, **kwargs):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
@ -105,8 +100,6 @@ class CustomOp(nn.Module):
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_neuron():
return self.forward_neuron
elif current_platform.is_out_of_tree():
return self.forward_oot
else:
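The two hunks above remove the Neuron branch from `CustomOp`'s per-platform dispatch. As a generic illustration of that dispatch style (a toy class, not vLLM's `CustomOp`), the sketch below resolves the platform once and binds the matching `forward_*` method, falling back to the native implementation for unknown platforms.

class ToyOp:

    def __init__(self, platform: str):
        # Resolve once at construction time; unknown platforms fall back
        # to the native (here: pure-Python) implementation.
        self._forward = {
            "cuda": self.forward_cuda,
            "tpu": self.forward_tpu,
        }.get(platform, self.forward_native)

    def forward(self, x):
        return self._forward(x)

    def forward_native(self, x):
        return x + 1

    def forward_cuda(self, x):
        return self.forward_native(x)

    def forward_tpu(self, x):
        return self.forward_native(x)


assert ToyOp("tpu").forward(1) == 2
assert ToyOp("some_new_platform").forward(1) == 2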
@ -95,13 +95,6 @@ class SiluAndMul(CustomOp):
self.op(out, x)
return out

def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
result = s * x_reshaped[:, d:]
return result.view(*x.shape[:-1], d)


@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
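The deleted `forward_neuron` above is a reshape-friendly way of computing SiLU(x[..., :d]) * x[..., d:]. The self-contained check below (plain PyTorch, no Neuron dependency) confirms that the reshape-based formulation matches the straightforward one.

import torch
import torch.nn.functional as F

x = torch.randn(4, 5, 8)
d = x.shape[-1] // 2

# Straightforward formulation: SiLU of the first half times the second half.
expected = F.silu(x[..., :d]) * x[..., d:]

# Reshape-based formulation, mirroring the deleted forward_neuron.
x2 = x.view(-1, x.shape[-1])
s = x2[:, :d] * torch.sigmoid(x2[:, :d])
result = (s * x2[:, d:]).view(*x.shape[:-1], d)

assert torch.allclose(result, expected, atol=1e-6)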
@ -26,7 +26,6 @@ QuantizationMethods = Literal[
"bitsandbytes",
"hqq",
"experts_int8",
"neuron_quant",
"ipex",
"quark",
"moe_wna16",
@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"ptpc_fp8": PTPCFp8Config,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,
@ -1,76 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
|
||||
|
||||
|
||||
class AlwaysSupportedDtypes(list):
|
||||
|
||||
def __contains__(self, item):
|
||||
return True
|
||||
|
||||
|
||||
class NeuronQuantConfig(QuantizationConfig):
|
||||
"""Int8 Quantization Config class for Neuron Backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dequant_dtype: str = "f16",
|
||||
quantize_method: str = "vector_dynamic",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
|
||||
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
|
||||
raise ValueError(
|
||||
f"Neuron quantization datatype {self.quant_dtype} is not valid,"
|
||||
f" the quantization datatype should match one of the below "
|
||||
f"types {SUPPORTED_QUANT_DTYPE_LIST}")
|
||||
self.dequant_dtype = dequant_dtype
|
||||
self.quantize_method = quantize_method
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "neuron_quant"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[str]:
|
||||
# Neuron implements custom handling logic for quantization support
|
||||
return AlwaysSupportedDtypes()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"This function should not be called with Neuron Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
|
||||
quantize_method = cls.get_from_keys(config, ["quantize_method"])
|
||||
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
|
||||
return cls(dequant_dtype=dequant_dtype,
|
||||
quantize_method=quantize_method)
|
||||
|
||||
def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
|
||||
if find_spec("transformers_neuronx") is not None:
|
||||
return self.get_quantization_config()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Neuron Quantization is only supported through"
|
||||
" transformers_neuronx.")
|
||||
|
||||
def get_quantization_config(self):
|
||||
from transformers_neuronx.config import QuantizationConfig
|
||||
return QuantizationConfig(quant_dtype=self.quant_dtype,
|
||||
dequant_dtype=self.dequant_dtype,
|
||||
quantize_method=self.quantize_method)
|
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
|
||||
from .common import apply_rotary_emb_torch
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding")
|
||||
@ -149,87 +149,6 @@ class RotaryEmbedding(CustomOp):
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def forward_neuron(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
def _apply_rotary_emb_neuron(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
# x1 = x[..., ::2]
|
||||
|
||||
# x2 = x[..., 1::2]
|
||||
d = x.shape[-1] // 2
|
||||
x_reshaped = x.view(-1, x.shape[-1])
|
||||
x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
|
||||
x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
if key is not None:
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
|
||||
if self.rotary_dim == self.head_size:
|
||||
query = apply_rotary_emb_dispatch(query, cos, sin,
|
||||
self.is_neox_style)
|
||||
query = query.reshape(query_shape)
|
||||
if key is not None:
|
||||
key = apply_rotary_emb_dispatch(key, cos, sin,
|
||||
self.is_neox_style)
|
||||
key = key.reshape(key_shape)
|
||||
else:
|
||||
head_size = query.shape[-1]
|
||||
query_reshaped = query.view(-1, head_size)
|
||||
query_pass = query_reshaped[:, self.rotary_dim:].view(
|
||||
*query.shape[:-1], head_size - self.rotary_dim)
|
||||
query_rot = query_reshaped[:, :self.rotary_dim].view(
|
||||
*query.shape[:-1], self.rotary_dim)
|
||||
query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass),
|
||||
dim=-1).reshape(query_shape)
|
||||
|
||||
if key is not None:
|
||||
key_reshaped = key.view(-1, head_size)
|
||||
key_pass = key_reshaped[:, self.rotary_dim:].view(
|
||||
*key.shape[:-1], head_size - self.rotary_dim)
|
||||
key_rot = key_reshaped[:, :self.rotary_dim].view(
|
||||
*key.shape[:-1], self.rotary_dim)
|
||||
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||
|
@ -1,476 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading Neuron models in transformers-neuronx
|
||||
framework."""
|
||||
import ast
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import get_quantization_config
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
|
||||
|
||||
TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
"auto": "f32",
|
||||
"half": "f16",
|
||||
"float16": "f16",
|
||||
"bfloat16": "bf16",
|
||||
"float": "f32",
|
||||
"float32": "f32",
|
||||
torch.float16: "f16",
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float32: "f32",
|
||||
}
|
||||
|
||||
# Models supported by Neuron.
|
||||
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = {
|
||||
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
|
||||
"LlamaForSampling", "LlamaForCausalLM"),
|
||||
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
|
||||
"MistralForSampling", "MistralForCausalLM")
|
||||
}
|
||||
|
||||
|
||||
class NeuronCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
on_device_sampling_disabled: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
|
||||
self.on_device_sampling_disabled = on_device_sampling_disabled
|
||||
if self.on_device_sampling_disabled:
|
||||
# Use default sampler
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
logits = self.model(input_ids,
|
||||
cache_ids=positions,
|
||||
start_ids=input_block_ids)
|
||||
return logits
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
|
||||
if self.on_device_sampling_disabled:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
# On-device sampling outputs the token ids directly.
|
||||
sampled_token_ids = logits.flatten()
|
||||
next_tokens = []
|
||||
sample_idx = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
samples = []
|
||||
for seq_id in seq_group.seq_ids:
|
||||
token_id = sampled_token_ids[sample_idx].item()
|
||||
samples.append(
|
||||
SequenceOutput(parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)}))
|
||||
sample_idx += 1
|
||||
next_tokens.append(
|
||||
CompletionSequenceGroupOutput(samples=samples,
|
||||
prompt_logprobs=None))
|
||||
|
||||
return SamplerOutput(outputs=next_tokens)
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
|
||||
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
|
||||
**kwargs)
|
||||
self.model.to_neuron()
|
||||
|
||||
|
||||
class NeuronSpeculationCausalLM(nn.Module):
|
||||
"""A Neuron-optimized causal language model with speculative decoding."""
|
||||
|
||||
SPECULATION_TERMINATION_ID = -1
|
||||
|
||||
def __init__(self, speculation_model) -> None:
|
||||
super().__init__()
|
||||
self.model = speculation_model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
tokens, counts = self.model.speculative_iteration(
|
||||
input_ids, positions, input_block_ids)
|
||||
|
||||
# Mark the end of accepted speculative tokens for each sequence with the
|
||||
# speculation termination id.
|
||||
batch_size, steps = tokens.shape
|
||||
mask = torch.arange(steps).expand(batch_size, -1) >= counts
|
||||
tokens[mask] = self.SPECULATION_TERMINATION_ID
|
||||
|
||||
return tokens
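A worked example of the masking performed just above: each row keeps its first `count` accepted speculative tokens and everything past that is overwritten with the termination id (-1). The tensors below are made up, and `counts` is assumed to carry one column per sequence so it broadcasts against the step axis.

import torch

tokens = torch.tensor([[11, 12, 13, 14],
                       [21, 22, 23, 24]])
counts = torch.tensor([[2], [3]])  # accepted tokens per sequence (assumed shape)
mask = torch.arange(tokens.shape[1]).expand(tokens.shape[0], -1) >= counts
tokens[mask] = -1  # SPECULATION_TERMINATION_ID
assert tokens.tolist() == [[11, 12, -1, -1], [21, 22, 23, -1]]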
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
batch_size, num_steps = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
for seq_id in sg.seq_ids
|
||||
]
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
accepted_token_ids_by_step = logits.transpose(0, 1)
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
|
||||
sampler_output_list = []
|
||||
for step_index in range(num_steps):
|
||||
if all(token_id == self.SPECULATION_TERMINATION_ID
|
||||
for token_id in accepted_token_ids_by_step[step_index]):
|
||||
break
|
||||
step_output_token_ids = []
|
||||
for sequence_index in range(batch_size):
|
||||
token_id = accepted_token_ids_by_step[step_index][
|
||||
sequence_index]
|
||||
step_output_token_ids.append(
|
||||
CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)})
|
||||
],
|
||||
prompt_logprobs=None))
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
return sampler_output_list
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _NEURON_SUPPORTED_MODELS:
|
||||
return arch
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported on Neuron "
|
||||
f"for now. Supported architectures: "
|
||||
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
||||
|
||||
|
||||
def _get_buckets(env: str, default_value: list[int]) -> list[int]:
|
||||
env_value = os.getenv(env)
|
||||
if env_value is None:
|
||||
return default_value
|
||||
buckets_remove_empty = filter(
|
||||
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
|
||||
buckets_int = map(int, buckets_remove_empty)
|
||||
buckets_list = list(buckets_int)
|
||||
return buckets_list
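`_get_buckets` above just turns a comma-separated environment variable into a list of ints, skipping empty entries and falling back to the default when the variable is unset. An equivalent, self-contained list-comprehension form of that parsing, using a made-up value:

import os

os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128, 512,2048,"
env_value = os.getenv("NEURON_CONTEXT_LENGTH_BUCKETS")
buckets = [int(v) for v in env_value.split(",") if v.strip()]
assert buckets == [128, 512, 2048]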
|
||||
|
||||
|
||||
def _get_default_neuron_config(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig):
|
||||
"""Generate a neuron config based on vllm config args."""
|
||||
from transformers_neuronx.config import ContinuousBatchingConfig
|
||||
from transformers_neuronx.constants import LAYOUT_BSH
|
||||
|
||||
continuous_batching_config = ContinuousBatchingConfig(
|
||||
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
|
||||
quant_config = dict(
|
||||
dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
quantize_method="vector_dynamic")
|
||||
neuron_quantization_config_builder = lambda quant: get_quantization_config(
|
||||
quant).from_config(quant_config).get_quant_method(None, "")
|
||||
# TODO: Add Paged attention config to the default neuron arguments.
|
||||
default_neuron_args = dict(
|
||||
collectives_layout=LAYOUT_BSH,
|
||||
attention_layout=LAYOUT_BSH,
|
||||
fuse_qkv=True,
|
||||
quant=neuron_quantization_config_builder(model_config.quantization)
|
||||
if model_config.quantization else None,
|
||||
continuous_batching=continuous_batching_config,
|
||||
weight_tiling=bool(model_config.quantization),
|
||||
on_device_generation=_get_neuron_on_device_generation_config(
|
||||
model_config))
|
||||
return default_neuron_args
|
||||
|
||||
|
||||
def _get_default_neuron_config_for_speculation(
|
||||
model_config: ModelConfig, parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig):
|
||||
"""Generate a neuron config for speculative decoding based on
|
||||
vllm config args."""
|
||||
from transformers_neuronx.config import ContinuousBatchingConfig
|
||||
from transformers_neuronx.constants import LAYOUT_BSH
|
||||
|
||||
continuous_batching_config = ContinuousBatchingConfig(
|
||||
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
|
||||
|
||||
default_neuron_args = dict(collectives_layout=LAYOUT_BSH,
|
||||
attention_layout=LAYOUT_BSH,
|
||||
fuse_qkv=True,
|
||||
on_device_embedding=True,
|
||||
continuous_batching=continuous_batching_config,
|
||||
on_device_generation=copy.deepcopy(
|
||||
model_config.neuron_sampling_params))
|
||||
return default_neuron_args
|
||||
|
||||
|
||||
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
|
||||
if not _is_neuron_on_device_sampling_disabled(model_config):
|
||||
return copy.deepcopy(model_config.neuron_sampling_params)
|
||||
return None
|
||||
|
||||
|
||||
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
|
||||
return not getattr(model_config, "neuron_sampling_params", None)
|
||||
|
||||
|
||||
def _get_neuron_config_after_override(default_neuron_config,
|
||||
overridden_neuron_config):
|
||||
from transformers_neuronx.config import (ContinuousBatchingConfig,
|
||||
GenerationConfig,
|
||||
KVCacheQuantizationConfig,
|
||||
NeuronConfig, QuantizationConfig,
|
||||
SparseAttnConfig)
|
||||
|
||||
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
|
||||
if sparse_attn:
|
||||
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
|
||||
**sparse_attn)
|
||||
|
||||
kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
|
||||
if kv_cache_quant:
|
||||
overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
|
||||
**kv_cache_quant)
|
||||
|
||||
continuous_batching = overridden_neuron_config.pop("continuous_batching",
|
||||
{})
|
||||
if continuous_batching:
|
||||
overridden_neuron_config[
|
||||
"continuous_batching"] = ContinuousBatchingConfig(
|
||||
**continuous_batching)
|
||||
|
||||
quant = overridden_neuron_config.pop("quant", {})
|
||||
if quant:
|
||||
overridden_neuron_config["quant"] = QuantizationConfig(**quant)
|
||||
|
||||
on_device_generation = overridden_neuron_config.pop(
|
||||
"on_device_generation", {})
|
||||
if on_device_generation:
|
||||
overridden_neuron_config["on_device_generation"] = GenerationConfig(
|
||||
**on_device_generation)
|
||||
default_neuron_config.update(overridden_neuron_config)
|
||||
return NeuronConfig(**default_neuron_config)
|
||||
|
||||
|
||||
def get_neuron_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||
"""Initializes a neuron-optimized model for inference."""
|
||||
# Create a model instance.
|
||||
model = NeuronCausalLM(
|
||||
model_config.hf_config,
|
||||
_is_neuron_on_device_sampling_disabled(model_config))
|
||||
|
||||
default_neuron_config_args = _get_default_neuron_config(
|
||||
model_config, parallel_config, scheduler_config)
|
||||
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
|
||||
model.load_weights(model_config.model,
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
return model.eval()
|
||||
|
||||
|
||||
def get_neuron_speculation_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
speculation_config: SpeculativeConfig):
|
||||
"""Initializes a neuron-optimized speculation model for inference.
|
||||
|
||||
This method is only applicable for speculation with a standalone draft model
|
||||
"""
|
||||
from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder
|
||||
|
||||
# For Eagle SD, we need to pass in additional parameters in neuron config.
|
||||
is_eagle = getattr(speculation_config.draft_model_config.hf_config,
|
||||
"is_eagle", False)
|
||||
|
||||
# Create target model instance.
|
||||
target_model = NeuronCausalLM(model_config.hf_config)
|
||||
|
||||
default_neuron_config_args = _get_default_neuron_config_for_speculation(
|
||||
model_config, parallel_config, scheduler_config)
|
||||
if is_eagle:
|
||||
default_neuron_config_args['is_eagle_target'] = True
|
||||
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
|
||||
target_model.load_weights(
|
||||
model_config.model,
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
target_model.eval()
|
||||
|
||||
# Create draft model instance.
|
||||
draft_model = NeuronCausalLM(
|
||||
speculation_config.draft_model_config.hf_config)
|
||||
|
||||
default_draft_neuron_config_args = (
|
||||
_get_default_neuron_config_for_speculation(
|
||||
speculation_config.draft_model_config, parallel_config,
|
||||
scheduler_config))
|
||||
if is_eagle:
|
||||
default_draft_neuron_config_args['is_eagle_draft'] = True
|
||||
default_draft_neuron_config_args['has_pre_attention_norm'] = False
|
||||
|
||||
draft_neuron_config = _get_neuron_config_after_override(
|
||||
default_draft_neuron_config_args,
|
||||
speculation_config.draft_model_config.override_neuron_config)
|
||||
|
||||
draft_model.load_weights(speculation_config.draft_model_config.model,
|
||||
tp_degree=speculation_config.
|
||||
draft_parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[
|
||||
speculation_config.draft_model_config.dtype],
|
||||
neuron_config=draft_neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
draft_model.eval()
|
||||
|
||||
num_speculative_tokens = speculation_config.num_speculative_tokens
|
||||
# Create speculation model instance.
|
||||
speculation_model = FusedSpeculativeDecoder(draft_model.model,
|
||||
target_model.model,
|
||||
num_speculative_tokens)
|
||||
speculation_model.to_neuron()
|
||||
|
||||
return NeuronSpeculationCausalLM(speculation_model)
|
||||
|
||||
|
||||
def get_neuron_eagle_speculation_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
speculation_config: SpeculativeConfig):
|
||||
"""Initializes a neuron-optimized EAGLE speculation model for inference."""
|
||||
from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder
|
||||
|
||||
# Create target model instance.
|
||||
target_model = NeuronCausalLM(model_config.hf_config)
|
||||
|
||||
default_neuron_config_args = _get_default_neuron_config_for_speculation(
|
||||
model_config, parallel_config, scheduler_config)
|
||||
default_neuron_config_args['is_eagle_target'] = True
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
|
||||
target_model.load_weights(
|
||||
model_config.model,
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
target_model.eval()
|
||||
|
||||
# Create draft model instance.
|
||||
draft_model = NeuronCausalLM(
|
||||
speculation_config.draft_model_config.hf_config)
|
||||
|
||||
default_draft_neuron_config_args = (
|
||||
_get_default_neuron_config_for_speculation(
|
||||
speculation_config.draft_model_config, parallel_config,
|
||||
scheduler_config))
|
||||
default_draft_neuron_config_args['is_eagle_draft'] = True
|
||||
default_draft_neuron_config_args['has_pre_attention_norm'] = False
|
||||
draft_neuron_config = _get_neuron_config_after_override(
|
||||
default_draft_neuron_config_args,
|
||||
speculation_config.draft_model_config.override_neuron_config)
|
||||
|
||||
draft_model.load_weights(speculation_config.draft_model_config.model,
|
||||
tp_degree=speculation_config.
|
||||
draft_parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[
|
||||
speculation_config.draft_model_config.dtype],
|
||||
neuron_config=draft_neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
draft_model.eval()
|
||||
|
||||
token_tree: dict[int, list[int]] = ast.literal_eval(
|
||||
speculation_config.speculative_token_tree)
|
||||
|
||||
speculation_model = EagleSpeculativeDecoder(draft_model.model,
|
||||
target_model.model,
|
||||
token_tree=token_tree)
|
||||
speculation_model.to_neuron()
|
||||
|
||||
return NeuronSpeculationCausalLM(speculation_model)
|
@ -1,685 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading Neuron models in
|
||||
neuronx-distributed-inference framework."""
|
||||
# Disabling yapf because yapf and isort have conflicts for the below imports
|
||||
# yapf: disable
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from neuronx_distributed_inference.models.config import (
|
||||
FusedSpecNeuronConfig, OnDeviceSamplingConfig)
|
||||
from neuronx_distributed_inference.models.mllama.utils import (
|
||||
create_vision_mask)
|
||||
from neuronx_distributed_inference.modules.lora_serving import (
|
||||
LoraServingConfig)
|
||||
from neuronx_distributed_inference.utils.hf_adapter import (
|
||||
load_pretrained_config)
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
|
||||
|
||||
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
|
||||
|
||||
# yapf: enable
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
"auto": "float32",
|
||||
"half": "float16",
|
||||
"float16": "float16",
|
||||
"bfloat16": "bfloat16",
|
||||
"float": "float32",
|
||||
"float32": "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.float32: "float32",
|
||||
}
|
||||
|
||||
# Models supported by Neuronx distributed for inference.
|
||||
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
|
||||
"LlamaForCausalLM":
|
||||
("neuronx_distributed_inference.models.llama.modeling_llama",
|
||||
"NeuronLlamaForCausalLM"),
|
||||
"MistralForCausalLM":
|
||||
("neuronx_distributed_inference.models.llama.modeling_llama",
|
||||
"NeuronLlamaForCausalLM"),
|
||||
"DbrxForCausalLM":
|
||||
("neuronx_distributed_inference.models.dbrx.modeling_dbrx",
|
||||
"NeuronDbrxForCausalLM"),
|
||||
"MixtralForCausalLM":
|
||||
("neuronx_distributed_inference.models.mixtral.modeling_mixtral",
|
||||
"NeuronMixtralForCausalLM"),
|
||||
"MllamaForConditionalGeneration":
|
||||
("neuronx_distributed_inference.models.mllama.modeling_mllama",
|
||||
"NeuronMllamaForCausalLM"),
|
||||
}
|
||||
|
||||
|
||||
class NeuronCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
sampling_params: torch.Tensor,
|
||||
prev_hidden: Optional[torch.Tensor] = None,
|
||||
adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
output = self.model(input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=positions,
|
||||
seq_ids=sorted_input_block_ids,
|
||||
sampling_params=sampling_params,
|
||||
prev_hidden=prev_hidden,
|
||||
adapter_ids=adapter_ids)
|
||||
# on-device sampling
|
||||
if self.config.neuron_config.on_device_sampling_config:
|
||||
output = output.hidden_states
|
||||
else:
|
||||
output = output.logits[:, -1, :]
|
||||
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
if input_block_ids.shape[0] != 1:
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
|
||||
return output
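The forward pass above reorders requests by sequence id before the device call and then undoes the permutation on the outputs via `argsort`. A toy round trip showing why `index_select` with `torch.argsort(sorted_indices)` restores the original order:

import torch

input_block_ids = torch.tensor([3, 1, 2])
outputs_in_original_order = torch.tensor([30, 10, 20])

# Pretend the device produced its outputs in sorted-id order.
sorted_ids, sorted_indices = torch.sort(input_block_ids)
device_outputs = torch.index_select(outputs_in_original_order, 0, sorted_indices)

restored_indices = torch.argsort(sorted_indices)
restored = torch.index_select(device_outputs, 0, restored_indices)
assert torch.equal(restored, outputs_in_original_order)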
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
# on-device sampling
|
||||
if self.config.neuron_config.on_device_sampling_config:
|
||||
batch_size = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
for seq_id in sg.seq_ids
|
||||
]
|
||||
assert len(seq_ids) == list(batch_size)[0], "batch size mismatch"
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
accepted_token_ids_by_step = logits.flatten()
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
|
||||
step_output_token_ids = []
|
||||
for i, seq_id in enumerate(seq_ids):
|
||||
token_id = accepted_token_ids_by_step[i]
|
||||
step_output_token_ids.append(
|
||||
CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)})
|
||||
],
|
||||
prompt_logprobs=None))
|
||||
return SamplerOutput(outputs=step_output_token_ids)
|
||||
else:
|
||||
return self.sampler(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
|
||||
**kwargs['neuron_config'])
|
||||
self.config.neuron_config = neuron_config
|
||||
config = neuronx_model_cls.get_config_cls()(
|
||||
neuron_config,
|
||||
load_config=load_pretrained_config(model_name_or_path))
|
||||
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
|
||||
usedforsecurity=False).hexdigest()
|
||||
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
|
||||
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
|
||||
elif os.path.exists(model_name_or_path):
|
||||
compiled_model_path = os.path.join(model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
else:
|
||||
compiled_model_path = os.path.join("local-models",
|
||||
model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
try:
|
||||
self.model = neuronx_model_cls(compiled_model_path)
|
||||
override_neuron_config = kwargs["override_neuron_config"]
|
||||
for k, v in override_neuron_config.items():
|
||||
setattr(self.model.config.neuron_config, k, v)
|
||||
self.model.load(compiled_model_path)
|
||||
return
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
logger.warning("Exception: %s", e)
|
||||
logger.warning("Failed to load the model from %s, Recompiling...",
|
||||
compiled_model_path)
|
||||
if not os.path.exists(model_name_or_path):
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
||||
saved_path = os.path.join("local-models", model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
model_name_or_path = saved_path
|
||||
self.model = neuronx_model_cls(model_name_or_path, config)
|
||||
self.model.compile(compiled_model_path)
|
||||
self.model.load(compiled_model_path)
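`load_weights` above keys its compiled-artifacts directory on an MD5 hash of the serialized Neuron config, so different configurations never share a cache entry, and `NEURON_COMPILED_ARTIFACTS` can override the location entirely. A minimal sketch of that path construction with a made-up config string standing in for `config.to_json_string()`:

import hashlib
import os

config_json = '{"tp_degree": 2, "seq_len": 4096}'  # stand-in for the real config JSON
hashed_config = hashlib.md5(config_json.encode("utf-8"),
                            usedforsecurity=False).hexdigest()
compiled_model_path = os.getenv(
    "NEURON_COMPILED_ARTIFACTS",
    os.path.join("my-model", "neuron-compiled-artifacts", hashed_config))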
|
||||
|
||||
|
||||
class NeuronMllamaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
on_device_sampling_disabled: bool = False) -> None:
|
||||
super().__init__()
|
||||
# has_image is the only multimodal input that is used in
|
||||
# token-generation
|
||||
# This is a cache (on CPU) that saves has_image data per sequence id
|
||||
# The number of entries in this cache is <= Batch-Size
|
||||
self.has_image_cache: dict[int, torch.Tensor] = {}
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.get_text_config().vocab_size, logits_as_input=True)
|
||||
|
||||
self.on_device_sampling_disabled = on_device_sampling_disabled
|
||||
if self.on_device_sampling_disabled:
|
||||
# Use default sampler
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
self.is_reorder_needed: bool = True
|
||||
|
||||
def read_from_has_image_cache(self, seq_ids: torch.Tensor):
|
||||
has_image_list = []
|
||||
for index in range(len(seq_ids)):
|
||||
seq_id = seq_ids[index].item()
|
||||
if seq_id in self.has_image_cache:
|
||||
has_image_list.append(self.has_image_cache[seq_id])
|
||||
else:
|
||||
has_image_list.append(torch.tensor([0]))
|
||||
return torch.tensor(has_image_list)
|
||||
|
||||
def write_to_has_image_cache(self, seq_ids: torch.Tensor,
|
||||
has_image: torch.Tensor):
|
||||
for index in range(len(seq_ids)):
|
||||
seq_id = seq_ids[index].item()
|
||||
if index < len(has_image):
|
||||
self.has_image_cache[seq_id] = has_image[index]
|
||||
else:
|
||||
self.has_image_cache[seq_id] = torch.zeros(1)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
seq_ids: torch.Tensor, pixel_values: torch.Tensor,
|
||||
aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
|
||||
has_image: torch.Tensor, sampling_params) -> torch.Tensor:
|
||||
|
||||
# We update the has_image cache during prefill
|
||||
# and read the has_image cache during decode
|
||||
if input_ids.shape[-1] > 1: # prefill
|
||||
self.write_to_has_image_cache(seq_ids, has_image)
|
||||
else:
|
||||
has_image = self.read_from_has_image_cache(seq_ids)
|
||||
bs = input_ids.shape[0]
|
||||
num_chunks = torch.zeros((bs, 1))
|
||||
aspect_ratios = torch.zeros((bs, 1, 2))
|
||||
|
||||
input_block_ids = seq_ids
|
||||
origin_input_block_ids = seq_ids
|
||||
if self.is_reorder_needed:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
|
||||
aspect_ratios = torch.index_select(aspect_ratios, 0,
|
||||
sorted_indices)
|
||||
num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
|
||||
has_image = torch.index_select(has_image, 0, sorted_indices)
|
||||
|
||||
self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
|
||||
output = self.model(
|
||||
input_ids.to(torch.int32),
|
||||
attention_mask=None,
|
||||
position_ids=positions.to(torch.int32),
|
||||
seq_ids=seq_ids.flatten().to(torch.int32),
|
||||
pixel_values=pixel_values.to(
|
||||
self.config.vision_config.torch_dtype),
|
||||
aspect_ratios=aspect_ratios.to(torch.int32),
|
||||
vision_mask=self.vision_mask.to(torch.int32),
|
||||
sampling_params=sampling_params,
|
||||
num_chunks=num_chunks.to(torch.int32),
|
||||
has_image=has_image.to(torch.int32),
|
||||
)
|
||||
if self.config.neuron_config.on_device_sampling_config:
|
||||
output = output.hidden_states
|
||||
else:
|
||||
output = output.logits[:, -1, :]
|
||||
|
||||
if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
return output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, hidden_states, sampling_metadata):
|
||||
if not self.on_device_sampling_disabled:
|
||||
with torch.profiler.record_function("sample"):
|
||||
hidden_states = hidden_states.flatten()
|
||||
res = []
|
||||
sample_idx = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
samples = []
|
||||
for seq_id in seq_ids:
|
||||
token_id = hidden_states[sample_idx].item()
|
||||
samples.append(
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)}))
|
||||
sample_idx += 1
|
||||
res.append(
|
||||
CompletionSequenceGroupOutput(samples=samples,
|
||||
prompt_logprobs=None))
|
||||
next_tokens = SamplerOutput(outputs=res)
|
||||
else:
|
||||
next_tokens = self.sampler(None, hidden_states, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
|
||||
**kwargs['neuron_config'])
|
||||
self.config.neuron_config = neuron_config
|
||||
logger.info("neuron_config buckets: %s",
|
||||
self.config.neuron_config.buckets)
|
||||
config = neuronx_model_cls.get_config_cls()(
|
||||
neuron_config,
|
||||
load_config=load_pretrained_config(model_name_or_path))
|
||||
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
|
||||
usedforsecurity=False).hexdigest()
|
||||
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
|
||||
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
|
||||
elif os.path.exists(model_name_or_path):
|
||||
compiled_model_path = os.path.join(model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
else:
|
||||
compiled_model_path = os.path.join("local-models",
|
||||
model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
try:
|
||||
self.model = neuronx_model_cls(compiled_model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
self.vision_token_id = tokenizer(
|
||||
"<|image|>", add_special_tokens=False).input_ids[0]
|
||||
self.model.load(compiled_model_path)
|
||||
return
|
||||
except (FileNotFoundError, ValueError):
|
||||
logger.warning("Failed to load the model from %s, Recompiling...",
|
||||
compiled_model_path)
|
||||
if not os.path.exists(model_name_or_path):
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
||||
saved_path = os.path.join("local-models", model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
model_name_or_path = saved_path
|
||||
self.model = neuronx_model_cls(model_name_or_path, config)
|
||||
|
||||
logger.info("\nCompiling and saving model to %s", model_name_or_path)
|
||||
|
||||
p = multiprocessing.Process(target=compile_model,
|
||||
args=(self, compiled_model_path))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
tokenizer.save_pretrained(compiled_model_path)
|
||||
logger.info("Successfully compiled and saved the model in %s",
|
||||
compiled_model_path)
|
||||
|
||||
# Read "<|image|>" token_id from the tokenizer
|
||||
self.vision_token_id = tokenizer("<|image|>",
|
||||
add_special_tokens=False).input_ids[0]
|
||||
logger.info("\nLoading model from compiled checkpoint...")
|
||||
self.model.load(compiled_model_path)
|
||||
|
||||
|
||||
def compile_model(neuron_model, traced_model_path):
|
||||
neuron_model.model.compile(traced_model_path)
|
||||
|
||||
|
||||
class NeuronSpeculationCausalLM(nn.Module):
|
||||
"""A Neuron-optimized causal language model with speculative decoding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
sampling_params: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
|
||||
output = self.model(input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=positions,
|
||||
seq_ids=sorted_input_block_ids,
|
||||
sampling_params=sampling_params)
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
|
||||
# CTX encoding
|
||||
if (positions[:, 0]).sum().item() == 0:
|
||||
output = output.fused_outputs[0][:, 0:1]
|
||||
if input_block_ids.shape[0] != 1:
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
return output
|
||||
|
||||
# Fused Spec (Generation)
|
||||
accepted_tokens_with_padding = output.fused_outputs[0]
|
||||
next_pos_ids = output.fused_outputs[-1]
|
||||
generated_token_counts = next_pos_ids - positions
|
||||
|
||||
assert torch.any(generated_token_counts == 0).item() is False, \
|
||||
"NxDI model generated no output for one or more sequences."
|
||||
|
||||
batch_size, steps = accepted_tokens_with_padding.shape
|
||||
mask = torch.arange(steps).expand(batch_size,
|
||||
-1) >= generated_token_counts
|
||||
accepted_tokens_with_padding[mask] = -1
|
||||
|
||||
if input_block_ids.shape[0] != 1:
|
||||
accepted_tokens_with_padding = torch.index_select(
|
||||
accepted_tokens_with_padding, 0, restored_indices)
|
||||
|
||||
return accepted_tokens_with_padding
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
batch_size, num_steps = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
for seq_id in sg.seq_ids
|
||||
]
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
accepted_token_ids_by_step = logits.transpose(0, 1)
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
|
||||
sampler_output_list = []
|
||||
for step_index in range(num_steps):
|
||||
if all(token_id == -1
|
||||
for token_id in accepted_token_ids_by_step[step_index]):
|
||||
break
|
||||
step_output_token_ids = []
|
||||
for sequence_index in range(batch_size):
|
||||
token_id = accepted_token_ids_by_step[step_index][
|
||||
sequence_index]
|
||||
step_output_token_ids.append(
|
||||
CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)})
|
||||
],
|
||||
prompt_logprobs=None))
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
return sampler_output_list
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
draft_model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
|
||||
**kwargs['neuron_config'])
|
||||
config = neuronx_model_cls.get_config_cls()(
|
||||
neuron_config,
|
||||
load_config=load_pretrained_config(model_name_or_path))
|
||||
|
||||
draft_neuron_config = copy.deepcopy(config.neuron_config)
|
||||
if not config.neuron_config.enable_eagle_speculation:
|
||||
draft_neuron_config.speculation_length = 0
|
||||
draft_neuron_config.trace_tokengen_model = True
|
||||
draft_neuron_config.enable_fused_speculation = False
|
||||
if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
|
||||
None):
|
||||
draft_neuron_config.modules_to_not_convert = (
|
||||
draft_neuron_config.draft_model_modules_to_not_convert)
|
||||
if config.neuron_config.enable_eagle_speculation:
|
||||
draft_neuron_config.is_eagle_draft = True
|
||||
draft_neuron_config.sequence_parallel_enabled = False
|
||||
draft_config = neuronx_model_cls.get_config_cls()(
|
||||
draft_neuron_config,
|
||||
load_config=load_pretrained_config(draft_model_name_or_path))
|
||||
fused_spec_config = (FusedSpecNeuronConfig(
|
||||
neuronx_model_cls._model_cls,
|
||||
draft_config=draft_config,
|
||||
draft_model_path=draft_model_name_or_path))
|
||||
config.fused_spec_config = fused_spec_config
|
||||
self.config.neuron_config = neuron_config
|
||||
|
||||
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
|
||||
usedforsecurity=False).hexdigest()
|
||||
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
|
||||
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
|
||||
elif os.path.exists(model_name_or_path):
|
||||
compiled_model_path = os.path.join(model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
else:
|
||||
compiled_model_path = os.path.join("local-models",
|
||||
model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
try:
|
||||
self.model = neuronx_model_cls(compiled_model_path)
|
||||
override_neuron_config = kwargs["override_neuron_config"]
|
||||
for k, v in override_neuron_config.items():
|
||||
setattr(self.model.config.neuron_config, k, v)
|
||||
self.model.load(compiled_model_path)
|
||||
return
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
logger.warning("Exception: %s", e)
|
||||
logger.warning("Failed to load the model from %s Recompiling...",
|
||||
compiled_model_path)
|
||||
if not os.path.exists(model_name_or_path):
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
||||
saved_path = os.path.join("local-models", model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
model_name_or_path = saved_path
|
||||
if not os.path.exists(draft_model_name_or_path):
|
||||
if draft_model_name_or_path != model_name_or_path:
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||
draft_model_name_or_path)
|
||||
saved_path = os.path.join("local-models",
|
||||
draft_model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
draft_model_name_or_path = saved_path
|
||||
else:
|
||||
draft_model_name_or_path = model_name_or_path
|
||||
config.fused_spec_config.draft_model_path = draft_model_name_or_path
|
||||
self.model = neuronx_model_cls(model_name_or_path, config)
|
||||
self.model.compile(compiled_model_path)
|
||||
self.model.load(compiled_model_path)
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _NEURON_SUPPORTED_MODELS:
|
||||
return arch
|
||||
raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")


def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig,
                               lora_serving_config: LoraServingConfig):
    """Generate a neuron config based on vllm config args."""
    on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True,
                                                       deterministic=False)
    batch_size = scheduler_config.max_num_seqs

    neuron_config = dict(
        tp_degree=parallel_config.tensor_parallel_size,
        ctx_batch_size=1,
        batch_size=batch_size,
        max_context_length=scheduler_config.max_model_len,
        seq_len=scheduler_config.max_model_len,
        enable_bucketing=True,
        is_continuous_batching=True,
        quantized=False,
        torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        padding_side="right",
        on_device_sampling_config=on_device_sampling_config,
        sequence_parallel_enabled=True,
        lora_serving_config=lora_serving_config)
    return neuron_config


def _get_default_speculation_config(model_config: ModelConfig,
                                    parallel_config: ParallelConfig,
                                    scheduler_config: SchedulerConfig,
                                    speculation_config: SpeculativeConfig):
    """Generate a neuron config for speculative decoding based on vllm config
    args."""
    neuron_config = dict(
        tp_degree=parallel_config.tensor_parallel_size,
        ctx_batch_size=1,
        batch_size=scheduler_config.max_num_seqs,
        max_context_length=scheduler_config.max_model_len,
        seq_len=scheduler_config.max_model_len,
        speculation_length=speculation_config.num_speculative_tokens,
        trace_tokengen_model=False,
        enable_fused_speculation=True,
        enable_bucketing=True,
        is_continuous_batching=True,
        quantized=False,
        torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        on_device_sampling_config=dict(
            top_k=1,
            do_sample=False,
        ))
    return neuron_config


def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Update default neuron config values with override args"""
    overridden_neuron_config = overridden_neuron_config or {}
    default_neuron_config.update(overridden_neuron_config)
    return default_neuron_config


def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig,
                     lora_serving_config: LoraServingConfig) -> nn.Module:
    """Initializes a neuron-optimized model for inference."""
    model_arch = _get_model_architecture(model_config.hf_config)
    if model_arch == "MllamaForConditionalGeneration":
        model = NeuronMllamaForCausalLM(model_config.hf_config)
    else:
        model = NeuronCausalLM(model_config.hf_config)
    default_neuron_config_args = _get_default_neuron_config(
        model_config, parallel_config, scheduler_config, lora_serving_config)
    neuron_config = _get_neuron_config_after_override(
        default_neuron_config_args, model_config.override_neuron_config)

    override_neuron_config = model_config.override_neuron_config
    model.load_weights(model_config.model,
                       neuron_config=neuron_config,
                       override_neuron_config=override_neuron_config)
    return model.eval()


def get_neuron_speculation_model(model_config: ModelConfig,
                                 parallel_config: ParallelConfig,
                                 scheduler_config: SchedulerConfig,
                                 speculation_config: SpeculativeConfig):
    """Initializes a neuron-optimized speculation model for inference.

    This model handles speculation using both a draft model and an EAGLE draft.
    """
    model = NeuronSpeculationCausalLM(model_config.hf_config)
    default_neuron_config_args = _get_default_speculation_config(
        model_config, parallel_config, scheduler_config, speculation_config)
    neuron_config = _get_neuron_config_after_override(
        default_neuron_config_args, model_config.override_neuron_config)

    override_neuron_config = model_config.override_neuron_config
    model.load_weights(model_config.model,
                       speculation_config.draft_model_config.model,
                       neuron_config=neuron_config,
                       override_neuron_config=override_neuron_config)
    return model.eval()
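
The override path above is a plain dict merge: `_get_neuron_config_after_override` lets any key supplied through `override_neuron_config` replace the generated default. A minimal, self-contained sketch of that behavior (the helper name and the sample keys below are illustrative, not part of the vLLM API):

from typing import Optional

# Toy version of the default-plus-override merge used above. Any key the
# user supplies simply replaces the generated default value; unknown keys
# are passed through untouched.
def merge_neuron_config(defaults: dict, overrides: Optional[dict]) -> dict:
    merged = dict(defaults)          # copy so the defaults stay reusable
    merged.update(overrides or {})   # user-supplied overrides win
    return merged

defaults = dict(tp_degree=2, seq_len=2048, enable_bucketing=True)
print(merge_neuron_config(defaults, dict(seq_len=4096)))
# -> {'tp_degree': 2, 'seq_len': 4096, 'enable_bucketing': True}
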
@ -169,37 +169,12 @@ def cpu_platform_plugin() -> Optional[str]:
    return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None


def neuron_platform_plugin() -> Optional[str]:
    tnx_installed = False
    nxd_installed = False
    logger.debug("Checking if Neuron platform is available.")
    try:
        import transformers_neuronx  # noqa: F401
        tnx_installed = True
        logger.debug("Confirmed Neuron platform is available because"
                     " transformers_neuronx is found.")
    except ImportError:
        pass

    try:
        import neuronx_distributed_inference  # noqa: F401
        nxd_installed = True
        logger.debug("Confirmed Neuron platform is available because"
                     " neuronx_distributed_inference is found.")
    except ImportError:
        pass

    is_neuron = tnx_installed or nxd_installed
    return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None


builtin_platform_plugins = {
    'tpu': tpu_platform_plugin,
    'cuda': cuda_platform_plugin,
    'rocm': rocm_platform_plugin,
    'xpu': xpu_platform_plugin,
    'cpu': cpu_platform_plugin,
    'neuron': neuron_platform_plugin,
}
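
Each plugin function above returns the dotted path of a Platform class when its software stack is importable, and None otherwise; vLLM then picks a platform from this table. A simplified, self-contained sketch of that probe-and-resolve pattern (the resolver loop below is illustrative, not vLLM's actual selection code):

from typing import Callable, Dict, Optional

def neuron_probe() -> Optional[str]:
    # Mirrors the shape of neuron_platform_plugin(): report the platform
    # class path only when the Neuron stack can actually be imported.
    try:
        import neuronx_distributed_inference  # noqa: F401
        return "vllm.platforms.neuron.NeuronPlatform"
    except ImportError:
        return None

def resolve_platform(
        plugins: Dict[str, Callable[[], Optional[str]]]) -> Optional[str]:
    # Return the first plugin that detects its hardware/software stack.
    for _name, probe in plugins.items():
        class_path = probe()
        if class_path is not None:
            return class_path
    return None

print(resolve_platform({"neuron": neuron_probe}))  # path string or None
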
@ -73,7 +73,6 @@ class PlatformEnum(enum.Enum):
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
@ -164,9 +163,6 @@ class Platform:
    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        return self._enum == PlatformEnum.NEURON

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT
@ -1,151 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import os
from functools import lru_cache
from typing import TYPE_CHECKING, Optional

from vllm import envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

from .interface import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

logger = init_logger(__name__)


class NeuronFramework(enum.Enum):
    TRANSFORMERS_NEURONX = "transformers-neuronx"
    NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference"


class NeuronPlatform(Platform):
    _enum = PlatformEnum.NEURON
    device_name: str = "neuron"
    device_type: str = "neuron"
    ray_device_key: str = "neuron_cores"
    supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
    dist_backend: str = "gloo"
    device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return "neuron"

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return False

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = \
                "vllm.worker.neuron_worker.NeuronWorker"

        if parallel_config.world_size > 1:
            parallel_config.distributed_executor_backend = "uni"

        if vllm_config.cache_config and vllm_config.model_config:
            # neuron needs block_size = max_model_len
            vllm_config.cache_config.block_size = \
                vllm_config.model_config.max_model_len  # type: ignore

        if vllm_config.model_config and vllm_config.model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        logger.warning("Pin memory is not supported on Neuron.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        if envs.VLLM_USE_V1:
            return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator"  # noqa
        else:
            return Platform.get_device_communicator_cls()

    @classmethod
    def use_all_gather(cls) -> bool:
        return True

    @classmethod
    @lru_cache
    def is_neuronx_distributed_inference(cls) -> bool:
        try:
            import neuronx_distributed_inference
        except ImportError:
            neuronx_distributed_inference = None
        return neuronx_distributed_inference is not None

    @classmethod
    @lru_cache
    def is_transformers_neuronx(cls) -> bool:
        try:
            import transformers_neuronx
        except ImportError:
            transformers_neuronx = None
        return transformers_neuronx is not None

    def get_neuron_framework_to_use(self):
        """Return the specified framework if corresponding installations are
        available.

        If no framework is specified, use neuronx-distributed-inference by
        default.
        If that's unavailable, check and switch to transformers-neuronx.
        """
        if not self.is_neuron():
            raise AssertionError(
                f"Neuron Framework unavailable for platform: {self}")

        tnx_installed = self.is_transformers_neuronx()
        nxd_installed = self.is_neuronx_distributed_inference()

        specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK")
        tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value
        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value
        if specified_framework == tnx_framework and tnx_installed:
            return self.TRANSFORMERS_NEURONX

        if ((specified_framework == nxd_framework and nxd_installed)
                or (specified_framework is None and nxd_installed)):
            return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE

        if specified_framework is None and tnx_installed:
            return NeuronFramework.TRANSFORMERS_NEURONX

        return None

    def use_neuronx_distributed(self):
        """
        Return True if the framework determined in get_neuron_framework_to_use()
        is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This
        is used to select the Neuron model framework and framework-specific
        configuration to apply during model compilation.
        """
        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
        return self.get_neuron_framework_to_use() == nxd_framework

    def use_transformers_neuronx(self):
        """
        Return True if the framework determined in get_neuron_framework_to_use()
        is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used
        to select the Neuron model framework and framework-specific
        configuration to apply during model compilation.
        """
        return self.get_neuron_framework_to_use(
        ) == NeuronFramework.TRANSFORMERS_NEURONX
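
The selection above prefers neuronx-distributed-inference whenever it is installed and nothing else is pinned through the VLLM_NEURON_FRAMEWORK environment variable. A standalone sketch of that ordering, with booleans standing in for the import probes (the function below is illustrative, not the platform API):

import os
from typing import Optional

def pick_framework(tnx_installed: bool, nxd_installed: bool) -> Optional[str]:
    # Mirrors get_neuron_framework_to_use(): an explicit request wins,
    # otherwise neuronx-distributed-inference is preferred when present.
    requested = os.environ.get("VLLM_NEURON_FRAMEWORK")
    if requested == "transformers-neuronx" and tnx_installed:
        return "transformers-neuronx"
    if ((requested == "neuronx-distributed-inference" and nxd_installed)
            or (requested is None and nxd_installed)):
        return "neuronx-distributed-inference"
    if requested is None and tnx_installed:
        return "transformers-neuronx"
    return None  # nothing usable is installed for the requested framework

print(pick_framework(tnx_installed=True, nxd_installed=True))
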
@ -1,455 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
from torch import nn

from vllm.config import DeviceConfig, VllmConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)


@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Used by the NeuronModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: SamplingMetadata = None
    multi_modal_kwargs: BatchedTensorInputs = None
    adapter_ids: Optional[str] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        return {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "input_block_ids": self.input_block_ids,
            "sampling_metadata": self.sampling_metadata,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        return ModelInputForNeuron(
            input_tokens=tensor_dict["input_tokens"],
            input_positions=tensor_dict["input_positions"],
            input_block_ids=tensor_dict["input_block_ids"],
            sampling_metadata=tensor_dict["sampling_metadata"],
            multi_modal_kwargs=tensor_dict["multi_modal_kwargs"],
        )


class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    """A model runner for AWS Neuron hardware"""

    # NEURON has an upper limit on the top_k
    _MAX_NEURON_SAMPLING_TOP_K = 256

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        ModelRunnerBase.__init__(self, vllm_config)

        if (self.model_config is not None
                and self.model_config.get_sliding_window()):
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (self.device_config if self.device_config
                              is not None else DeviceConfig())
        self.lora_config = vllm_config.lora_config
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
        # turn off on-device sampling.
        self._on_device_sampling_disabled = int(
            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))

        # NEURON needs to update sampling parameters when request IDs change
        # across batches. This variable stores the previous batch's request IDs
        # to determine if an update is needed.
        self._previous_batch_request_ids: List[str] = []

        if not self._on_device_sampling_disabled:
            self._init_neuron_sampling()

    def _init_neuron_sampling(self) -> None:
        if current_platform.use_transformers_neuronx():
            from transformers_neuronx.config import GenerationConfig
        else:
            from transformers import GenerationConfig
        logger.warning(
            "On-device sampling is turned on in Neuron by default, only "
            "top_k, top_p, and temperature are current supported sampling "
            "parameters. To turn off the on-device sampling, please set "
            "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
        self.model_config.neuron_sampling_params = GenerationConfig(
            max_length=self.scheduler_config.max_model_len,
            do_sample=True,
            per_batch_line=True,
            top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
                * self.scheduler_config.max_num_seqs,
            top_p=[1.0] * self.scheduler_config.max_num_seqs,
            temperature=[1.0] * self.scheduler_config.max_num_seqs,
            dynamic=True,
            global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)

    def load_model(self) -> None:
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)

    def get_model(self) -> nn.Module:
        return self.model

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_kwargs = seq_group_metadata.multi_modal_data
            if mm_kwargs:
                mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs)
                multi_modal_kwargs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=max_seq_len,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=max_seq_len,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=1,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=1,
                                               dtype=torch.long,
                                               device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = None

        if not self._on_device_sampling_disabled:
            for seq_group_metadata in seq_group_metadata_list:
                sampling_params = seq_group_metadata.sampling_params
                top_k, top_p, temperature = (
                    self._convert_to_neuron_sampling_params(sampling_params))
                sampling_params.top_k = top_k
                sampling_params.top_p = top_p
                sampling_params.temperature = temperature

        # we need multi_modal_data for later tokens as well
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        for seq_group_metadata in seq_group_metadata_list:
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                multi_modal_kwargs_list.append(mm_data)
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        if current_platform.use_transformers_neuronx(
        ) and not self._on_device_sampling_disabled:
            # Once the request IDs are changed in current iteration, we will
            # update the on-device sampling parameters.
            current_batch_request_ids = [
                seq_group_meta_data.request_id
                for seq_group_meta_data in seq_group_metadata_list
            ]
            if current_batch_request_ids != self._previous_batch_request_ids:
                self._update_neuron_sampling_params(seq_group_metadata_list)
                self._previous_batch_request_ids = current_batch_request_ids

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

    def _update_neuron_sampling_params(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        # Update Neuron sampling parameters (GenerationConfig in Neuron)
        current_sampling_params = self.model_config.neuron_sampling_params
        assert current_sampling_params is not None, (
            f"Failed to update sampling_params, "
            f"current sampling params is {current_sampling_params}")

        is_update_needed = False

        top_k = current_sampling_params.top_k
        top_p = current_sampling_params.top_p
        temperature = current_sampling_params.temperature

        # The index of a sequence's sampling parameters in neuron is equal to
        # its index in `input_block_ids`.
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params

            seq_group_top_k = sampling_params.top_k
            seq_group_top_p = sampling_params.top_p
            seq_group_temperature = sampling_params.temperature

            for seq_id in seq_ids:
                index = seq_group_metadata.block_tables[seq_id][0]
                if (top_k[index] != seq_group_top_k
                        or top_p[index] != seq_group_top_p
                        or temperature[index] != seq_group_temperature):
                    is_update_needed = True

                top_k[index] = seq_group_top_k
                top_p[index] = seq_group_top_p
                temperature[index] = seq_group_temperature

        # update_generation_config is only available in transformers-neuronx
        if is_update_needed and current_platform.use_transformers_neuronx():
            self.model.model.update_generation_config(current_sampling_params)

    def _convert_to_neuron_sampling_params(
            self, sampling_params: SamplingParams) -> Tuple[int, float, float]:
        # Returns the top_k, top_p and temperature parameters for neuron.
        top_k = sampling_params.top_k
        top_p = sampling_params.top_p
        temperature = sampling_params.temperature

        if temperature == 0.0:
            # Enable greedy sampling on zero temperature
            return (1, 1.0, 1.0)
        if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
            top_k = self._MAX_NEURON_SAMPLING_TOP_K

        return (top_k, top_p, temperature)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        # extract top_k, top_p and temperature from model_input for neuron
        # forward call
        sampling_params = (torch.tensor([[
            seq_group.sampling_params.top_k, seq_group.sampling_params.top_p,
            seq_group.sampling_params.temperature
        ] for seq_group in model_input.sampling_metadata.seq_groups]))

        if current_platform.use_neuronx_distributed():
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                input_block_ids=model_input.input_block_ids,
                sampling_params=sampling_params,
                adapter_ids=model_input.adapter_ids,
                **MultiModalKwargs.as_kwargs(
                    model_input.multi_modal_kwargs or {},
                    device=self.device,
                ),
            )
        elif current_platform.use_transformers_neuronx():
            # [TODO] validate on-device sampling
            # The model signature may need change for on-device sampling
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                input_block_ids=model_input.input_block_ids,
                **MultiModalKwargs.as_kwargs(
                    model_input.multi_modal_kwargs or {},
                    device=self.device,
                ),
            )

        # Compute the logits only if the on-device sampling is turned off as
        # on-device sampling outputs the token ids.
        if self._on_device_sampling_disabled:
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)
        else:
            logits = hidden_states

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    def process_multi_modal_data_neuron(self, mm_data):
        # this is a no-op for NeuronModelRunner
        return mm_data

    def remove_all_loras(self):
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def add_lora(self, lora_request: LoRARequest):
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def list_loras(self) -> Set[int]:
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")
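
Because Neuron's on-device sampler only understands top_k, top_p, and temperature, the runner above normalizes every request: temperature 0 becomes greedy (top_k=1) and any top_k outside [1, 256] is clamped to the hardware limit. A small standalone sketch of that rule (the constant and function name below are illustrative):

from typing import Tuple

MAX_NEURON_TOP_K = 256  # mirrors _MAX_NEURON_SAMPLING_TOP_K above

def to_neuron_sampling(top_k: int, top_p: float,
                       temperature: float) -> Tuple[int, float, float]:
    if temperature == 0.0:
        return 1, 1.0, 1.0               # zero temperature => greedy decode
    if top_k < 1 or top_k > MAX_NEURON_TOP_K:
        top_k = MAX_NEURON_TOP_K         # clamp "disabled"/oversized top_k
    return top_k, top_p, temperature

assert to_neuron_sampling(-1, 0.9, 0.7) == (256, 0.9, 0.7)
assert to_neuron_sampling(40, 0.9, 0.0) == (1, 1.0, 1.0)
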
@ -1,189 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A Neuron worker class."""
import os
from typing import List, Optional, Set, Tuple

import torch.distributed

from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)

logger = init_logger(__name__)


class NeuronWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes the model on a group of neuron cores.
    """

    model_runner: NeuronModelRunner

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        self.lora_config = vllm_config.lora_config

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        neuron_framework = current_platform.get_neuron_framework_to_use()
        if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
            self.model_runner = self.get_tnx_model_runner(vllm_config)
        elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
            self.model_runner = self.get_neuronx_distributed_model_runner(
                vllm_config)
        else:
            raise NotImplementedError(
                "Specified framework" +
                f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" +
                " is either not installed or not supported." +
                " Supported frameworks: " +
                "[transformers-neuronx, neuronx-distributed-inference]")

    def get_tnx_model_runner(self, vllm_config):
        assert (self.lora_config
                is None), ("LoRA is not supported for TransformersNeuronX "
                           "framework.")
        if self.speculative_config is not None:
            raise NotImplementedError(
                "Speculative decoding is not supported for TransformersNeuronX"
            )
        return NeuronModelRunner(vllm_config=vllm_config)

    def get_neuronx_distributed_model_runner(self, vllm_config):
        from vllm.worker.neuronx_distributed_model_runner import (
            NeuronxDistributedModelRunner)
        if self.speculative_config is not None:
            assert (self.lora_config is None), (
                "LoRA is not supported for Speculative Decoding")
            raise NotImplementedError(
                "Speculative decoding is not supported for NeuronxDistributed")
        return NeuronxDistributedModelRunner(vllm_config=vllm_config)

    def init_device(self) -> None:
        self.init_distributed_environment()

        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        Swapping is not yet supported, so always return num_cpu_blocks=0.

        We configure num_gpu_blocks to be equal to max_num_seqs.
        """
        # Set the number of GPU blocks to be the same as the maximum number of
        # sequences that can be processed in a single batch. This is equivalent
        # to schedule without PagedAttention.
        num_gpu_blocks = self.scheduler_config.max_num_seqs + 1

        # Swap not yet supported with Neuron backend.
        num_cpu_blocks = 0

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache.
        """

        # Different values are not tested.
        assert num_cpu_blocks == 0
        assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    @property
    def do_metadata_broadcast(self) -> bool:
        return False

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return None

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        return WorkerInput(num_seq_groups=len(
            execute_model_req.seq_group_metadata_list), )

    def execute_worker(self, worker_input: WorkerInput) -> None:
        pass

    def get_cache_block_size_bytes(self) -> int:
        """Determine the size in bytes of a cache block.

        This is required for speculative decoding; it is not yet implemented.
        """
        raise NotImplementedError

    def init_distributed_environment(self):
        """Neuron uses transformers-neuronx for tensor parallelism.

        vLLM still needs the environment initialized when TP/PP > 1
        """
        init_distributed_environment(
            world_size=1,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend=current_platform.dist_backend,
        )

        ensure_model_parallel_initialized(
            1,
            1,
        )

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if current_platform.use_transformers_neuronx():
            raise NotImplementedError(
                f"{type(self)} does not support LoRA with Neuron Framework "
                f"Transformers NeuronX")
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        if current_platform.use_transformers_neuronx():
            raise NotImplementedError(
                f"{type(self)} does not support LoRA with Neuron Framework "
                f"Transformers NeuronX")
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        if current_platform.use_transformers_neuronx():
            raise NotImplementedError(
                f"{type(self)} does not support LoRA with Neuron Framework "
                f"Transformers NeuronX")
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        if current_platform.use_transformers_neuronx():
            raise NotImplementedError(
                f"{type(self)} does not support LoRA with Neuron Framework "
                f"Transformers NeuronX")
        return self.model_runner.list_loras()
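
The block accounting above works because the Neuron platform config pins block_size to max_model_len, so every sequence occupies exactly one KV block and the "GPU" block budget collapses to the batch size plus one spare. A rough sketch of that arithmetic (the helper and the example numbers are illustrative):

def neuron_block_budget(max_num_seqs: int, max_model_len: int) -> dict:
    block_size = max_model_len         # one block spans an entire sequence
    num_gpu_blocks = max_num_seqs + 1  # one block per sequence, plus a spare
    num_cpu_blocks = 0                 # swapping is not supported on Neuron
    return dict(block_size=block_size,
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks)

print(neuron_block_budget(max_num_seqs=8, max_model_len=2048))
# -> {'block_size': 2048, 'num_gpu_blocks': 9, 'num_cpu_blocks': 0}
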
@ -1,294 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List, Optional, Set

import torch
from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import (
    get_all_supported_aspect_ratios)
from neuronx_distributed_inference.modules.generation.sampling import (
    prepare_sampling_params)
from neuronx_distributed_inference.modules.lora_serving import (
    LoraCheckpoint, LoraServingConfig)

from vllm.config import VllmConfig
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuronx_distributed import (
    _get_model_architecture, get_neuron_model)
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
                                             NeuronModelRunner)

logger = init_logger(__name__)


class NeuronxDistributedModelRunner(NeuronModelRunner):

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        super().__init__(vllm_config)
        self.lora_checkpoint = None
        self.model = None
        self.lora_serving_config = None

    @staticmethod
    def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]):
        if not lora_modules:
            return None
        return {_.get("name"): _.get("path") for _ in lora_modules}

    def _get_nxdi_lora_config(self):
        override_neuron_config = self.model_config.override_neuron_config
        lora_modules = override_neuron_config.pop("lora_modules", None)
        target_modules = override_neuron_config.pop("target_modules", None)
        lora_ckpt_paths = self._get_lora_paths_strings(lora_modules)
        if self.lora_config.max_loras < len(lora_ckpt_paths):
            raise ValueError(
                "Number of LoRAs (%s) exceeds maximum "
                "allowed (%s)", len(lora_ckpt_paths),
                self.lora_config.max_loras)

        return LoraServingConfig(
            max_loras=self.lora_config.max_loras,
            max_lora_rank=self.lora_config.max_lora_rank,
            target_modules=target_modules,
            lora_ckpt_paths=lora_ckpt_paths,
        )

    def load_model(self) -> None:
        # Update LoRA config
        if self.lora_config is not None:
            self.lora_serving_config = self._get_nxdi_lora_config()
            self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config)
        self.model = get_neuron_model(
            self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            lora_serving_config=self.lora_serving_config)

    def get_nxd_sampling_params(self, sampling_metadata):
        if self.model.config.neuron_config.on_device_sampling_config:
            max_topk = (self.model.config.neuron_config.
                        on_device_sampling_config.global_topk)
        else:
            max_topk = self.model.config.vocab_size

        top_k = [1] * self.scheduler_config.max_num_seqs
        top_p = [1.0] * self.scheduler_config.max_num_seqs
        temperature = [1.0] * self.scheduler_config.max_num_seqs

        for index, sequenceGroupToSample in enumerate(
                sampling_metadata.seq_groups):
            top_k[index] = (sequenceGroupToSample.sampling_params.top_k
                            if sequenceGroupToSample.sampling_params.top_k > 0
                            else max_topk)
            top_p[index] = sequenceGroupToSample.sampling_params.top_p
            temperature[index] = (
                sequenceGroupToSample.sampling_params.temperature)

        sampling_params = prepare_sampling_params(
            batch_size=self.scheduler_config.max_num_seqs,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature)
        return sampling_params

    def get_multi_modal_data_neuron(self, input_images):
        raise NotImplementedError("need to restore multi-modal support")

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        if _get_model_architecture(
                self.model.config) != "MllamaForConditionalGeneration":
            return super().execute_model(model_input, kv_caches,
                                         intermediate_tensors, num_steps)

        sampling_params = self.get_nxd_sampling_params(
            model_input.sampling_metadata)

        if model_input.multi_modal_kwargs.get('pixel_values') is not None:
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                seq_ids=model_input.input_block_ids,
                pixel_values=model_input.multi_modal_kwargs.get(
                    'pixel_values'),
                aspect_ratios=model_input.multi_modal_kwargs.get(
                    'aspect_ratios'),
                sampling_params=sampling_params,
                num_chunks=model_input.multi_modal_kwargs.get('num_chunks'),
                has_image=model_input.multi_modal_kwargs.get(
                    'has_image').squeeze(1),
            )
        else:
            bs = model_input.input_tokens.shape[0] if (model_input.input_tokens
                                                       is not None) else 1
            empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560],
                                             dtype=torch.bfloat16)
            empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64)
            num_chunks = torch.zeros((bs, 1), dtype=torch.int32)
            has_image = torch.zeros([bs], dtype=torch.int32)
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                seq_ids=model_input.input_block_ids,
                pixel_values=empty_pixel_values,
                aspect_ratios=empty_aspect_ratios,
                sampling_params=sampling_params,
                num_chunks=num_chunks,
                has_image=has_image,
            )

        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=model_input.sampling_metadata,
        )

        return [output]

    def process_multi_modal_data_neuron(self, mm_data):
        # Neuron uses aspect_ratios instead of aspect_ratio_ids
        all_supported_aspect_ratios = get_all_supported_aspect_ratios(
            self.model.config.vision_config.max_num_tiles)
        aspect_ratio_ids = mm_data.get("aspect_ratio_ids")
        mm_data["aspect_ratios"] = torch.tensor(
            all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0)

        # Neuron's num_chunks is HF's num_tiles
        mm_data["num_chunks"] = mm_data.get("num_tiles")

        # Input has an image if it has pixel_values
        bs = mm_data["num_chunks"].shape[0]
        pixel_values = mm_data.get("pixel_values")
        if pixel_values is not None and not torch.all(pixel_values == 0):
            mm_data["has_image"] = torch.ones(bs)
        else:
            mm_data["has_image"] = torch.zeros(bs)
        return mm_data

    def _get_lora_adapter_ids(self, seq_group_metadata_list):
        # set LoRA adapter IDs for multi-lora serving
        batch_size = len(seq_group_metadata_list)
        if self.lora_checkpoint is not None:
            # "0" indicates NxDI to use the base model for inference
            adapter_ids = ["0"] * batch_size
            for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
                if seq_group_metadata.lora_request is not None:
                    adapter_ids[
                        idx] = seq_group_metadata.lora_request.lora_name

            # convert adapter_ids from strings to integers
            adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices(
                adapter_ids, batch_size)
        else:
            adapter_ids = torch.zeros((batch_size), dtype=torch.int32)

        return adapter_ids

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = None

        if not self._on_device_sampling_disabled:
            for seq_group_metadata in seq_group_metadata_list:
                sampling_params = seq_group_metadata.sampling_params
                top_k, top_p, temperature = (
                    self._convert_to_neuron_sampling_params(sampling_params))
                sampling_params.top_k = top_k
                sampling_params.top_p = top_p
                sampling_params.temperature = temperature

        # we need multi_modal_data for later tokens as well
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        for seq_group_metadata in seq_group_metadata_list:
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                multi_modal_kwargs_list.append(mm_data)
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs,
                                   adapter_ids=lora_adapter_ids)

    def remove_all_loras(self):
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def add_lora(self, lora_request: LoRARequest):
        logger.warning(
            "Adding LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config. If you supplied "
            "the parameter, you can ignore this warning. Ignoring"
            "lora request: ", lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def list_loras(self) -> Set[int]:
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")
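
For multi-LoRA serving, the runner above tags every sequence with an adapter name ("0" meaning the base model) and then converts those names to integer indices for NxDI. A toy illustration of that mapping (the name-to-index table below stands in for LoraCheckpoint.convert_adapter_ids_to_indices, which performs the real conversion):

from typing import Dict, List, Optional

def map_adapter_ids(requested: List[Optional[str]],
                    name_to_index: Dict[str, int]) -> List[int]:
    # Sequences without a LoRA request fall back to "0", the base model.
    names = [name if name is not None else "0" for name in requested]
    return [name_to_index[name] for name in names]

table = {"0": 0, "sql-lora": 1, "chat-lora": 2}
print(map_adapter_ids(["sql-lora", None, "chat-lora"], table))  # [1, 0, 2]
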