[V0 deprecation] Deprecate V0 Neuron backend (#21159)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon
2025-09-06 16:15:18 -07:00
committed by GitHub
parent 848562bd49
commit 4172235ab7
46 changed files with 10 additions and 5462 deletions


@ -149,19 +149,3 @@ steps:
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
env:
DOCKER_BUILDKIT: "1"
- block: "Build Neuron release image"
key: block-neuron-release-image-build
depends_on: ~
- label: "Build and publish Neuron release image"
depends_on: block-neuron-release-image-build
agents:
queue: neuron-postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest"
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)"
env:
DOCKER_BUILDKIT: "1"


@ -1,64 +0,0 @@
#!/bin/bash
# This script builds the Neuron docker image and runs the API server inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -e
set -v
image_name="neuron/vllm-ci"
container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
HF_CACHE="$(realpath ~)/huggingface"
mkdir -p "${HF_CACHE}"
HF_MOUNT="/root/.cache/huggingface"
HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN)
NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache"
mkdir -p "${NEURON_COMPILE_CACHE_URL}"
NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache"
# Try building the docker image
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws
# prune old images and containers to save disk space, but only once a day,
# by using a timestamp file in /tmp.
if [ -f /tmp/neuron-docker-build-timestamp ]; then
last_build=$(cat /tmp/neuron-docker-build-timestamp)
current_time=$(date +%s)
if [ $((current_time - last_build)) -gt 86400 ]; then
# Remove dangling images (those that are not tagged and not used by any container)
docker image prune -f
# Remove unused volumes / force the system prune for old images as well.
docker volume prune -f && docker system prune -f
echo "$current_time" > /tmp/neuron-docker-build-timestamp
fi
else
date "+%s" > /tmp/neuron-docker-build-timestamp
fi
docker build -t "${image_name}" -f docker/Dockerfile.neuron .
# Setup cleanup
remove_docker_container() {
docker image rm -f "${image_name}" || true;
}
trap remove_docker_container EXIT
# Run the image
docker run --rm -it --device=/dev/neuron0 --network bridge \
-v "${HF_CACHE}:${HF_MOUNT}" \
-e "HF_HOME=${HF_MOUNT}" \
-e "HF_TOKEN=${HF_TOKEN}" \
-v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
--name "${container_name}" \
${image_name} \
/bin/bash -c "
set -e; # Exit on first error
python3 /workspace/vllm/examples/offline_inference/neuron.py;
python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
for f in /workspace/vllm/tests/neuron/2_core/*.py; do
echo \"Running test file: \$f\";
python3 -m pytest \$f -v --capture=tee-sys;
done
"


@ -2,7 +2,6 @@ include LICENSE
include requirements/common.txt
include requirements/cuda.txt
include requirements/rocm.txt
include requirements/neuron.txt
include requirements/cpu.txt
include CMakeLists.txt


@ -1,56 +0,0 @@
# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04"
FROM $BASE_IMAGE
RUN echo "Base image is $BASE_IMAGE"
# Install some basic utilities
RUN apt-get update && \
apt-get install -y \
git \
python3 \
python3-pip \
ffmpeg libsm6 libxext6 libgl1
### Mount Point ###
# When launching the container, mount the code directory to /workspace
ARG APP_MOUNT=/workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}/vllm
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity
RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install pytest
# uninstall the transformers-neuronx package explicitly to avoid a version conflict
RUN python3 -m pip uninstall -y transformers-neuronx
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
RUN python3 -m pip install -U \
'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements/neuron.txt
ENV VLLM_TARGET_DEVICE neuron
RUN --mount=type=bind,source=.git,target=.git \
pip install --no-build-isolation -v -e .
# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
# install transformers-neuronx package as an optional dependency (for V0)
# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict
RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps
RUN python3 -m pip install sentencepiece transformers==4.48.0 -U
# overwrite entrypoint to run bash script
RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py
CMD ["/bin/bash"]


@ -1,49 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
max_num_seqs=8,
# The max_model_len and block_size arguments are required to be the same as
# the max sequence length when targeting the neuron device.
# Currently, this is a known limitation in continuous batching support
# in transformers-neuronx.
# TODO(liangfu): Support paged-attention in transformers-neuronx.
max_model_len=1024,
block_size=1024,
# ruff: noqa: E501
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron",
tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__":
main()


@ -1,61 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"What is annapurna labs?",
]
def main():
# Create a sampling params object.
sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)
# Create an LLM.
llm = LLM(
model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct",
speculative_config={
"model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
"num_speculative_tokens": 5,
"max_model_len": 2048,
},
max_num_seqs=4,
# The max_model_len and block_size arguments are required to be the same as
# the max sequence length when targeting the neuron device.
# Currently, this is a known limitation in continuous batching support
# in neuronx-distributed-inference.
max_model_len=2048,
block_size=2048,
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron",
tensor_parallel_size=32,
override_neuron_config={
"enable_eagle_speculation": True,
"enable_fused_speculation": True,
},
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}")
if __name__ == "__main__":
main()


@ -1,63 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from vllm import LLM, SamplingParams
# Creates XLA HLO graphs for all the context length buckets.
os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# Creates XLA HLO graphs for all the token gen buckets.
os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
# Quantizes the neuron model weights to int8.
# The default config for quantization uses the int8 dtype.
os.environ["NEURON_QUANT_DTYPE"] = "s8"
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
max_num_seqs=8,
# The max_model_len and block_size arguments are required to be the same as
# the max sequence length when targeting the neuron device.
# Currently, this is a known limitation in continuous batching support
# in transformers-neuronx.
# TODO(liangfu): Support paged-attention in transformers-neuronx.
max_model_len=2048,
block_size=2048,
# ruff: noqa: E501
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron",
quantization="neuron_quant",
override_neuron_config={
"cast_logits_dtype": "bfloat16",
},
tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__":
main()


@ -1,110 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import requests
import torch
from neuronx_distributed_inference.models.mllama.utils import add_instruct
from PIL import Image
from vllm import LLM, SamplingParams, TextPrompt
def get_image(image_url):
image = Image.open(requests.get(image_url, stream=True).raw)
return image
# Model Inputs
PROMPTS = [
"What is in this image? Tell me a story",
"What is the recipe of mayonnaise in two sentences?",
"Describe this image",
"What is the capital of Italy famous for?",
]
IMAGES = [
get_image(
"https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
),
None,
get_image(
"https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
),
None,
]
SAMPLING_PARAMS = [
dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16)
for _ in range(len(PROMPTS))
]
def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params):
# Prepare all inputs for mllama generation, including:
# 1. putting the text prompt into the instruct chat template
# 2. composing the single text and single image prompt into vLLM's TextPrompt class
# 3. preparing the sampling parameters
input_image = single_image
has_image = torch.tensor([1])
if isinstance(single_image, torch.Tensor) and single_image.numel() == 0:
has_image = torch.tensor([0])
instruct_prompt = add_instruct(prompt, has_image)
inputs = TextPrompt(prompt=instruct_prompt)
if input_image is not None:
inputs["multi_modal_data"] = {"image": input_image}
sampling_params = SamplingParams(**sampling_params)
return inputs, sampling_params
def print_outputs(outputs):
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def main():
assert (
len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS)
), f"""Text, image prompts and sampling parameters should have the
same batch size; but got {len(PROMPTS)}, {len(IMAGES)},
and {len(SAMPLING_PARAMS)}"""
# Create an LLM.
llm = LLM(
model="meta-llama/Llama-3.2-11B-Vision-Instruct",
max_num_seqs=1,
max_model_len=4096,
block_size=4096,
device="neuron",
tensor_parallel_size=32,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True,
"save_sharded_checkpoint": True,
"on_device_sampling_config": {
"global_topk": 1,
"dynamic": False,
"deterministic": False,
},
},
)
batched_inputs = []
batched_sample_params = []
for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS):
inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params)
# test batch-size = 1
outputs = llm.generate(inputs, sampling_params)
print_outputs(outputs)
batched_inputs.append(inputs)
batched_sample_params.append(sampling_params)
# test batch-size = 4
outputs = llm.generate(batched_inputs, batched_sample_params)
print_outputs(outputs)
if __name__ == "__main__":
main()


@ -1,64 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""
import os
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, I am a language model and I can help",
"The president of the United States is",
"The capital of France is",
]
def config_buckets():
"""Configure context length and token gen buckets."""
# Creates XLA HLO graphs for all the context length buckets.
os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# Creates XLA HLO graphs for all the token gen buckets.
os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
def initialize_llm():
"""Create an LLM with speculative decoding."""
return LLM(
model="openlm-research/open_llama_7b",
speculative_config={
"model": "openlm-research/open_llama_3b",
"num_speculative_tokens": 4,
"max_model_len": 2048,
},
max_num_seqs=4,
max_model_len=2048,
block_size=2048,
device="neuron",
tensor_parallel_size=32,
)
def process_requests(llm: LLM, sampling_params: SamplingParams):
"""Generate texts from prompts and print them."""
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def main():
"""Main function that sets up the llm and processes prompts."""
config_buckets()
llm = initialize_llm()
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, top_k=1)
process_requests(llm, sampling_params)
if __name__ == "__main__":
main()


@ -1,9 +0,0 @@
# Common dependencies
-r common.txt
# Dependencies for Neuron devices
packaging>=24.2
setuptools>=77.0.3,<80.0.0
torch-neuronx >= 2.5.0
neuronx-cc>=2.0.0a0
torchvision # Required for Llama3.2 multimodal image preprocessing


@ -413,8 +413,7 @@ def _no_device() -> bool:
def _is_cuda() -> bool:
has_cuda = torch.version.cuda is not None
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
and not (_is_neuron() or _is_tpu()))
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu())
def _is_hip() -> bool:
@ -422,10 +421,6 @@ def _is_hip() -> bool:
or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None
def _is_neuron() -> bool:
return VLLM_TARGET_DEVICE == "neuron"
def _is_tpu() -> bool:
return VLLM_TARGET_DEVICE == "tpu"
@ -470,25 +465,6 @@ def get_rocm_version():
return None
def get_neuronxcc_version():
import sysconfig
site_dir = sysconfig.get_paths()["purelib"]
version_file = os.path.join(site_dir, "neuronxcc", "version",
"__init__.py")
# Read the version file
with open(version_file) as fp:
content = fp.read()
# Extract the version using a regular expression
match = re.search(r"__version__ = '(\S+)'", content)
if match:
# Return the version string
return match.group(1)
else:
raise RuntimeError("Could not find Neuron version in the output")
def get_nvcc_cuda_version() -> Version:
"""Get the CUDA version from nvcc.
@ -541,12 +517,6 @@ def get_vllm_version() -> str:
rocm_version = get_rocm_version() or torch.version.hip
if rocm_version and rocm_version != MAIN_CUDA_VERSION:
version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
elif _is_neuron():
# Get the Neuron version
neuron_version = str(get_neuronxcc_version())
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"{sep}neuron{neuron_version_str}"
elif _is_tpu():
version += f"{sep}tpu"
elif _is_cpu():
@ -591,8 +561,6 @@ def get_requirements() -> list[str]:
requirements = modified_requirements
elif _is_hip():
requirements = _read_requirements("rocm.txt")
elif _is_neuron():
requirements = _read_requirements("neuron.txt")
elif _is_tpu():
requirements = _read_requirements("tpu.txt")
elif _is_cpu():
@ -601,7 +569,7 @@ def get_requirements() -> list[str]:
requirements = _read_requirements("xpu.txt")
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
"Unsupported platform, please use CUDA, ROCm, or CPU.")
return requirements
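
Not part of the diff: a minimal sketch of the net effect of the setup.py hunks above, assuming requirement selection can be reduced to a plain mapping (the function name and dictionary below are illustrative, not vLLM's actual API). With the Neuron branch removed, VLLM_TARGET_DEVICE=neuron now falls through to the same unsupported-platform error as any unknown target.

def select_requirements_file(target_device: str) -> str:
    """Illustrative sketch of requirement selection after this change."""
    files = {
        "cuda": "requirements/cuda.txt",
        "rocm": "requirements/rocm.txt",
        "tpu": "requirements/tpu.txt",
        "cpu": "requirements/cpu.txt",
        "xpu": "requirements/xpu.txt",
    }
    if target_device not in files:  # "neuron" now lands here
        raise ValueError(
            "Unsupported platform, please use CUDA, ROCm, or CPU.")
    return files[target_device]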


@ -287,15 +287,6 @@ def test_prefix_cache_default():
},
"mm-processor-kwargs"
),
(
'{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}',
{
"cast_logits_dtype": "bfloat16",
"sequence_parallel_norm": True,
"sequence_parallel_norm_threshold": 2048,
},
"override-neuron-config"
),
])
# yapf: enable
def test_composite_arg_parser(arg, expected, option):


@ -1,43 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import FastGELU, SiluAndMul
from vllm.platforms import current_platform
@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
(7, 512, torch.half),
(7, 512, torch.float),
(83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
activation: str,
num_tokens: int,
d: int,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device)
if activation == "silu_and_mul":
layer = SiluAndMul()
fn = layer.forward_native
elif activation == "gelu_fast":
layer = FastGELU()
fn = F.gelu
else:
raise NotImplementedError(
f"activation {activation} is not implemented.")
assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
out = layer.to(device=device).forward_neuron(x)
ref_out = fn(x.cpu())
torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0)


@ -1,154 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki
from vllm.attention.ops.nki_flash_attn import (
load_block_tables, transform_block_tables_for_indirect_load)
def is_power_of_2(n):
return n > 0 and (n & (n - 1) == 0)
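# Editor's note (illustrative, not in the original test): the bit trick works because
# a power of two has a single set bit, e.g. 8 = 0b1000 and 7 = 0b0111, so 8 & 7 == 0,
# while 6 & 5 = 0b100 != 0.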
def nki_load_and_transform_block_tables(
block_tables,
num_tiles,
num_blocks_per_tile,
num_head,
head_id,
block_size_tiling_factor,
):
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
block_tables_sbuf = load_block_tables(block_tables, num_tiles,
num_blocks_per_tile)
# we need to pass an Index as head_id
head_id = nl.arange(1)[None, :] + head_id
block_tables_transposed = transform_block_tables_for_indirect_load(
block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
B_P_SIZE = 128
assert block_tables_transposed.shape[1] == B_P_SIZE
out = nl.ndarray(
block_tables_transposed.shape,
dtype=nl.int32,
buffer=nl.shared_hbm,
)
for i in nl.affine_range(block_tables_transposed.shape[0]):
nl.store(dst=out[i], value=block_tables_transposed[i])
return out
def ref_block_tables_transform(
block_tables,
num_tiles,
num_blocks_per_tile,
num_head,
head_id,
block_size_tiling_factor,
):
assert block_tables.numel() == num_tiles * num_blocks_per_tile
block_tables = block_tables.view(num_tiles, num_blocks_per_tile)
B_F_SIZE = 128
num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
block_tables = F.pad(
block_tables,
(0, 0, 0, num_tiles_padded - num_tiles),
"constant",
0,
)
block_tables = block_tables * num_head + head_id
block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
block_tables = block_tables * block_size_tiling_factor + offset
block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()
num_blocks_per_tile = block_tables_transposed.shape[0]
assert num_blocks_per_tile % B_F_SIZE == 0
return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
B_F_SIZE, num_tiles_padded)
@pytest.mark.parametrize(
"q_head_per_kv_head,head_id",
[
(1, 0),
(3, 1),
],
)
@pytest.mark.parametrize(
"num_tiles,num_blocks_per_tile",
[
(1, 1),
(13, 16),
(17, 128),
(35, 512),
(128, 128),
(130, 64),
(280, 256),
(315, 1),
],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
monkeypatch: pytest.MonkeyPatch,
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
) -> None:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
compiler_flags_str = " ".join([
"-O1",
"--retry_failed_compilation",
])
with monkeypatch.context() as m:
m.setenv("NEURON_CC_FLAGS", compiler_flags_str)
torch.manual_seed(10000)
torch.set_printoptions(sci_mode=False)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
B_P_SIZE = 128
if num_blocks_per_tile < B_P_SIZE:
assert B_P_SIZE % num_blocks_per_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
else:
block_size_tiling_factor = 1
max_num_blocks = 100000
block_tables = torch.randint(
0,
max_num_blocks,
(num_tiles * num_blocks_per_tile, ),
dtype=torch.int32,
)
nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
block_tables.to(device=device),
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
block_size_tiling_factor,
).cpu()
ref_out = ref_block_tables_transform(
block_tables,
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
block_size_tiling_factor,
)
assert (nki_out.shape == ref_out.shape
), f"{nki_out.shape=} != {ref_out.shape=}"
assert torch.all(nki_out == ref_out)


@ -1,86 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.attention.ops.nki_flash_attn import reshape_and_cache
@pytest.mark.parametrize(
"num_tokens, n_kv_head, d_head, num_blocks, block_size",
[
# Small model configuration (e.g., GPT-2 small)
(32, 12, 64, 4, 128), # Typical sequence processing
(1, 12, 64, 4, 128), # Single token update
(128, 12, 64, 4, 128), # Longer sequence
# Medium model configuration (e.g., GPT-2 medium)
(64, 16, 96, 8, 256), # Standard batch
(256, 16, 96, 8, 256), # Large batch
# Large model configuration (e.g., GPT-3 style)
(48, 32, 128, 16, 512), # Typical processing window
(512, 32, 128, 16, 512), # Full context window
# Edge cases and stress tests
(1024, 8, 32, 32, 32), # Many tokens, small heads
(16, 64, 256, 4, 64), # Few tokens, many heads
(2048, 24, 128, 64, 128), # Large scale test
# Minimal configurations for debugging
(4, 2, 16, 2, 16), # Tiny test case
(1, 1, 8, 1, 8), # Minimal possible
])
def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
block_size):
# Set random seed for reproducibility
torch.manual_seed(42)
# Create CPU tensors for reference implementation
key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
torch.tensor(d_head))
value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
torch.tensor(d_head))
key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]
# Run reference implementation on CPU
block_indices = torch.div(slot_mapping_cpu,
block_size,
rounding_mode="floor")
block_offsets = slot_mapping_cpu % block_size
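# Editor's note (illustrative, not in the original test): e.g. with block_size=128,
# slot 133 maps to block 133 // 128 = 1 at offset 133 % 128 = 5.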
for i in range(num_tokens):
block_idx = block_indices[i]
block_offset = block_offsets[i]
key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]
# Create XLA device tensors
device = torch.device('xla')
key = key_cpu.to(device)
value = value_cpu.to(device)
key_cache = torch.zeros_like(key_cache_cpu, device=device)
value_cache = torch.zeros_like(value_cache_cpu, device=device)
slot_mapping = slot_mapping_cpu.to(device)
kv_cache = torch.stack([key_cache, value_cache])
# Run vectorized implementation on XLA device
reshape_and_cache(key, value, kv_cache, slot_mapping)
key_cache, value_cache = torch.unbind(kv_cache, dim=0)
# Move results back to CPU for comparison
key_cache_result = key_cache.cpu()
value_cache_result = value_cache.cpu()
# Assert results match
torch.testing.assert_close(key_cache_result,
key_cache_cpu,
rtol=1e-5,
atol=1e-5)
torch.testing.assert_close(value_cache_result,
value_cache_cpu,
rtol=1e-5,
atol=1e-5)


@ -1,57 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
(7, 8, False, torch.half),
(83, 768, False, torch.half),
(83, 768, True, torch.half),
(83, 768, True, torch.bfloat16),
(83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None
residual_cpu = residual.cpu() if add_residual else None
ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
out = layer.to(device=device)(x, residual)
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
if add_residual:
assert out[0].is_xla, "output tensor is expected to be XLA tensor"
torch.testing.assert_close(out[0].cpu(),
ref_out[0],
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out[1].cpu(),
ref_out[1],
atol=1e-2,
rtol=1e-2)
else:
assert out.is_xla, "output tensor is expected to be XLA tensor"
torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)


@ -1,95 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available
class MockLogitsProcessor(LogitsProcessor):
def __init__(self, vocab_size: int, scale: float,
fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size, scale=scale)
self.fake_logits = fake_logits.clone()
def forward(self, *args, **kwargs):
with patch(
"vllm.model_executor.layers.logits_processor._prune_hidden_states",
lambda x, y: x
), patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
def _prepare_test(
batch_size: int
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
return input_tensor, fake_logits, logits_processor
RANDOM_SEEDS = list(range(8))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
import torch_xla.core.xla_model as xm
device = xm.xla_device()
set_random_seed(seed)
torch.set_default_device("cpu")
batch_size = random.randint(1, 256)
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits
seq_group_metadata_list = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
lm_head=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
fake_logits *= logits_processor.scale
torch.testing.assert_close(logits_processor_output[:, 1],
fake_logits[:, 1],
rtol=1e-4,
atol=0.0)


@ -1,127 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from unittest.mock import MagicMock
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import NeuronModelRunner
os.environ[
'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value
def _create_neuron_model_runner(model: str, *args,
**kwargs) -> NeuronModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
vllm_config = VllmConfig(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
)
neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config)
return neuron_model_runner
def test_update_neuron_sampling_params_not_full_batch():
os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
model_runner = _create_neuron_model_runner(
"facebook/opt-125m",
seed=0,
dtype="float16",
max_num_seqs=2,
)
assert not model_runner._on_device_sampling_disabled
# Test sampling param updating only when TNx is the framework;
# NxDI handles sampling parameter updating inside the model.
if current_platform.use_transformers_neuronx():
model_mock = MagicMock()
model_runner.model = model_mock
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0.5,
top_k=1,
top_p=0.5),
block_tables={0: [1]},
)
]
model_runner.prepare_model_input(seq_group_metadata_list)
# Index neuron sampling parameters based on block_tables indices.
# The first block_id of sequence 0 is 1, so its parameters are
# placed at index 1. So the sampling parameters will be:
# Index 0: default sampling parameters
# Index 1: sequence 0's sampling parameters.
neuron_sampling_params = (
model_runner.model_config.neuron_sampling_params)
assert neuron_sampling_params.temperature == [1.0, 0.5]
assert neuron_sampling_params.top_k == [
model_runner._MAX_NEURON_SAMPLING_TOP_K, 1
]
assert neuron_sampling_params.top_p == [1.0, 0.5]
model_mock.model.update_generation_config.assert_called_once_with(
neuron_sampling_params)
def test_update_neuron_sampling_params_full_batch():
os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
model_runner = _create_neuron_model_runner(
"facebook/opt-125m",
seed=0,
dtype="float16",
max_num_seqs=2,
)
assert not model_runner._on_device_sampling_disabled
# Test sampling param updating only when TNx is the framework;
# NxDI handles sampling parameter updating inside the model.
if current_platform.use_transformers_neuronx():
model_mock = MagicMock()
model_runner.model = model_mock
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0.5,
top_k=1,
top_p=0.5),
block_tables={0: [1]},
),
SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={1: SequenceData.from_seqs([4, 5, 6])},
sampling_params=SamplingParams(temperature=0.2,
top_k=2,
top_p=0.2),
block_tables={1: [0]},
)
]
model_runner.prepare_model_input(seq_group_metadata_list)
# Index neuron sampling parameters based on block_tables indices.
# The first block_id of sequence 0 is 1, so its parameters are
# placed at index 1. So the sampling parameters will be:
# Index 0: sequence 1's sampling parameters
# Index 1: sequence 0's sampling parameters.
neuron_sampling_params = (
model_runner.model_config.neuron_sampling_params)
assert neuron_sampling_params.temperature == [0.2, 0.5]
assert neuron_sampling_params.top_k == [2, 1]
assert neuron_sampling_params.top_p == [0.2, 0.5]
model_mock.model.update_generation_config.assert_called_once_with(
neuron_sampling_params)


@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
def test_get_supported_act_dtypes():
neuron_quant_config = NeuronQuantConfig()
supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
target_list = ["any_dtype1", "any_dtype2"]
for dtype in target_list:
assert dtype in supported_act_dtypes


@ -1,514 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from vllm.utils import cdiv
class BlockDiagonalCausalFromBottomRightMask:
@staticmethod
def _from_seqlens(query_lens, seq_lens, block_size=None):
from torch import logical_and, logical_or
contexted = block_size is None
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
n_queries = sum(query_lens)
num_seqs = len(query_lens)
if contexted:
key_lens_blockaligned = seq_lens
else:
n_blocks_per_seq = (context_lens + block_size - 1) // block_size
offset_per_seq = n_blocks_per_seq * block_size
key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
n_keys = sum(key_lens_blockaligned)
a = (torch.arange(n_queries).reshape(n_queries,
1).expand(n_queries, n_keys))
b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)
prior_mask = torch.zeros(n_queries, n_keys)
new_masks: list[torch.Tensor] = []
for seq_id in range(num_seqs):
ri = q_cumsum[seq_id]
ci = k_cumsum[seq_id]
nr = query_lens[seq_id]
if contexted:
nc = seq_lens[seq_id]
a_offset = ci + nc - ri - nr
new_mask = (a + a_offset) >= b
else:
nc = context_lens[seq_id]
a_offset = ci + nc - 1
new_mask = a_offset >= b
left_mask = b >= ci
top_mask = a >= ri
bottom_mask = a < (ri + nr)
new_mask = logical_and(
logical_and(logical_and(new_mask, left_mask), top_mask),
bottom_mask,
)
prior_mask = logical_or(prior_mask, new_mask)
new_masks = new_masks + [new_mask]
return prior_mask
@staticmethod
def from_seqlens(query_lens, seq_lens, block_size=None):
contexted = block_size is None
if contexted:
prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
query_lens, seq_lens)
active_mask = None
else:
prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
query_lens, seq_lens, block_size)
active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
query_lens, query_lens)
return prior_mask, active_mask
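# Editor's note, not part of the original file: a hand-worked example of the mask
# this helper builds. For query_lens=[2] and seq_lens=[4] (contexted case,
# block_size=None), the two query tokens sit at positions 2 and 3 of a length-4
# sequence, so from_seqlens([2], [4]) returns active_mask=None and a
# bottom-right-aligned causal prior_mask equivalent to
#   [[True, True, True, False],
#    [True, True, True, True]]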
def ref_softmax(x: torch.Tensor,
dim: int,
mixed_precision=False,
return_max_reduce=False):
max_value = torch.amax(x, dim=dim, keepdims=True)
exp = torch.exp(x - max_value)
if mixed_precision:
sum_value = torch.sum(exp.to(torch.float32),
dim=dim,
keepdims=True).to(x.dtype)
else:
sum_value = torch.sum(exp, dim=dim, keepdims=True)
if return_max_reduce:
return exp / sum_value, max_value, torch.reciprocal(sum_value)
return exp / sum_value
def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
attn_mask: Optional[torch.Tensor] = None,
return_max_reduce: Optional[bool] = False,
) -> torch.Tensor:
scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
masked_score = scaled_qk + attn_mask.float()
if return_max_reduce:
norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
masked_score, dim=-1, return_max_reduce=True)
else:
norm_score = ref_softmax(masked_score, dim=-1)
out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
if return_max_reduce:
return (
out,
cached_max,
cached_sum_reciprocal,
norm_score,
masked_score,
scaled_qk,
)
else:
return (out, )
def ref_context_attention(
query,
key,
value,
query_lens,
seq_lens,
head_size,
num_queries_per_kv,
return_max_reduce=False,
):
scale = float(1.0 / (head_size**0.5))
if num_queries_per_kv > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens)
# convert the binary mask to an additive mask with large negative values
attn_mask = torch.logical_not(attn_mask)
attn_mask = attn_mask.float() * -30000
output, *debug_tensors = ref_masked_attention(
query,
key,
value,
scale,
attn_mask,
return_max_reduce=return_max_reduce,
)
output = output.unsqueeze(1)
if return_max_reduce:
cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
debug_tensors)
return (
output,
cached_max,
cached_sum_reciprocal,
lse,
masked_score,
scaled_qk,
)
else:
return output
def sample_inputs(
prefill_batch_size,
decode_batch_size,
min_query_len,
max_query_len,
min_ctx_len,
max_ctx_len,
block_size,
num_heads,
num_kv_heads,
head_size,
dtype,
):
batch_size = prefill_batch_size + decode_batch_size
max_model_len = (max_query_len + max_ctx_len) * 4
max_block_per_request = max_model_len // block_size
cache_size = (batch_size * max_block_per_request) + 2
prefill_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (prefill_batch_size, ),
dtype=torch.long).tolist()
decode_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (decode_batch_size, ),
dtype=torch.long).tolist()
ctx_lens = prefill_ctx_lens + decode_ctx_lens
query_lens = torch.randint(
min_query_len,
max_query_len + 1,
(prefill_batch_size, ),
dtype=torch.long,
).tolist() + [1 for _ in range(decode_batch_size)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1, 1)
torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1, 1)
key, value = kv.unbind(dim=1)
k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:batch_size * max_block_per_request].view(
batch_size, max_block_per_request)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
dim=0)
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long),
dim=0)
for i in range(batch_size):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
cur_ctx += block_size
block_id += 1
kv_cache = torch.stack([k_cache, v_cache])
return (
query,
k,
v,
kv_cache,
block_table,
key,
value,
query_lens,
seq_lens,
)
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
context_lens = seq_lens - query_lens
blocks_per_seq = (context_lens + block_size - 1) // block_size
num_seqs = len(seq_lens)
active_blocks: list[int] = []
for seq_id in range(num_seqs):
active_blocks = (
active_blocks +
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
return F.pad(
torch.tensor(active_blocks, dtype=torch.int32),
(0, num_blocks - len(active_blocks)),
"constant",
0,
)
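# Editor's note (illustrative, not in the original test): with block_size=32 and a
# context length of 100 (seq_len minus query_len), blocks_per_seq = ceil(100 / 32) = 4,
# so the first 4 block-table entries of that sequence are kept and the concatenated
# list is zero-padded to num_blocks entries.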
@pytest.mark.parametrize(
"prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision",
[
# Test minimal configurations (small block size)
(1, 199, 1, 512, 4, 2, 8, False
), # minimal block size, small dimensions
(1, 199, 1, 512, 4, 2, 8, True), # same with mixed precision
# Test common/medium configurations
(4, 12, 32, 2048, 32, 8, 64, False), # common case, larger heads
(4, 12, 32, 2048, 16, 4, 32,
True), # medium size, mixed precision, grouped-query attention (GQA)
# Test large configurations
(4, 12, 256, 8192, 8, 1, 128, False), # large blocks, large head size
(4, 12, 256, 8192, 64, 8, 64, True), # large blocks, many heads
# Test asymmetric configurations
(2, 24, 64, 4096, 12, 4, 96, False), # varied batch sizes
(8, 8, 128, 2048, 24, 2, 48, True), # balanced batches
# Test edge cases
(1, 128, 16, 1024, 4, 2, 16, False), # large decode batch
(16, 4, 8, 1024, 4, 2, 128, True), # large prefill batch
(4, 12, 32, 2048, 16, 1, 32, True), # multi-head attention (MHA)
(4, 12, 32, 2048, 16, 16, 32, True), # multi-query attention (MQA)
])
@torch.inference_mode()
def test_contexted_kv_attention(
monkeypatch: pytest.MonkeyPatch,
prefill_batch_size: int,
decode_batch_size: int,
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import torch_xla.core.xla_model as xm
from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
reorder_context_mask)
assert large_tile_size % block_size == 0
device = xm.xla_device()
compiler_flags_str = " ".join([
"-O1",
"--retry_failed_compilation",
])
with monkeypatch.context() as m:
m.setenv("NEURON_CC_FLAGS", compiler_flags_str)
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
torch.set_default_device("cpu")
dtype = torch.float32
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
num_kv_heads = num_heads // num_queries_per_kv
(
query,
k_active,
v_active,
kv_cache,
block_table,
key,
value,
query_lens,
seq_lens,
) = sample_inputs(
prefill_batch_size=prefill_batch_size,
decode_batch_size=decode_batch_size,
min_query_len=min_query_len,
max_query_len=max_query_len,
min_ctx_len=min_ctx_len,
max_ctx_len=max_ctx_len,
block_size=block_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
)
output_ref = ref_context_attention(
query,
key,
value,
query_lens,
seq_lens,
head_size,
num_queries_per_kv,
return_max_reduce=False,
)
# build neuron program
B_P_SIZE = 128
assert (large_tile_size >= B_P_SIZE
), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
def pad_to_multiple(a, b):
return cdiv(a, b) * b
def pad_to_next_power_of_2(a):
assert a > 0
return 2**int(a - 1).bit_length()
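# Editor's note (illustrative, not in the original test):
#   pad_to_multiple(100, 64) == 128       # cdiv(100, 64) = 2, and 2 * 64 = 128
#   pad_to_next_power_of_2(100) == 128    # (100 - 1).bit_length() = 7, and 2**7 = 128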
# calculate input shapes
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks = cdiv(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
large_tile_size // block_size)
context_kv_len = num_active_blocks * block_size
assert (
context_kv_len %
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
# pad QKV tensors
pad_dims = (
0,
0,
0,
0,
0,
max_num_queries - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k_active, pad_dims, "constant", 0)
v = F.pad(v_active, pad_dims, "constant", 0)
# permute QKV tensors
# query: (1, n_heads, d, seq_q)
# key: (1, n_kv_heads, d, seq_k)
# value: (1, n_kv_heads, seq_v, d)
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()
# transform block table
active_block_table = get_active_block_tables(
block_table.cpu(),
torch.tensor(query_lens).cpu(),
torch.tensor(seq_lens).cpu(),
block_size,
num_active_blocks,
)
# Build attention masks
prior_mask, active_mask = (
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens, block_size=block_size))
prior_mask_padded = F.pad(
prior_mask,
(
0,
context_kv_len - prior_mask.shape[1],
0,
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
).bool()
active_mask_padded = F.pad(
active_mask,
(
0,
max_num_queries - active_mask.shape[1],
0,
max_num_queries - active_mask.shape[0],
),
"constant",
0,
).bool()
attn_mask = torch.concat([prior_mask_padded, active_mask_padded],
dim=1)
attn_mask = reorder_context_mask(attn_mask, large_tile_size,
block_size)
input_args = (
query.to(device=device),
k.to(device=device),
v.to(device=device),
kv_cache.to(device=device),
active_block_table.to(device=device),
attn_mask.to(device=device),
)
input_kwargs = dict(
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
LARGE_TILE_SZ=large_tile_size,
)
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
num_actual_tokens = sum(query_lens)
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.cpu().permute(0, 2, 1, 3)
output_nki = output_nki[0, :num_actual_tokens, :, :]
output_ref_padded = F.pad(
output_ref,
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
"constant",
0,
)
output_ref = output_ref_padded.transpose(
0, 1)[0, :num_actual_tokens, :, :]
torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)


@ -1,68 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for rotary embedding on Neuron
"""
import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
@pytest.mark.parametrize(
"max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [
(16, False, 32, 32, 1024, True),
(16, False, 32, 128, 1024, True),
(16, True, 32, 32, 1024, True),
(16, True, 32, 128, 1024, True),
(16, False, 32, 128, 1024, False),
(16, True, 32, 128, 1024, False),
])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
head_size, seq_len, use_key):
import torch_xla.core.xla_model as xm
device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
batch_size = 1
base = 10000
num_heads = 8
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device="cpu")
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device="cpu")
key = torch.randn_like(query) if use_key else None
assert positions.is_cpu, \
"reference input tensor is expected to be CPU tensor."
ref_query, ref_key = rot.to(device="cpu").forward_native(
positions, query, key)
out_query, out_key = rot.to(device=device).forward_neuron(
positions.to(device=device), query.to(device=device),
key.to(device=device) if key is not None else None)
if use_key:
assert out_query.is_xla and out_key.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_key.cpu(),
ref_key,
atol=1e-2,
rtol=1e-2)
else:
assert out_key is None, "expected returned key to be None"
assert out_query.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_query.cpu(),
ref_query,
atol=1e-2,
rtol=1e-2)


@ -1,101 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Callable
from unittest.mock import patch
import pytest
import torch
import torch_xla.distributed.xla_multiprocessing as xmp
from typing_extensions import ParamSpec
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.utils import get_distributed_init_method, get_open_port
_P = ParamSpec("_P")
def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to reinitialize the Neuron Runtime before executing a test.
This is necessary for distributed tests which need to reallocate Neuron
Cores to separate subprocesses.
"""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
runtime = torch.classes.neuron.Runtime()
runtime.initialize()
runtime.unsafe_close()
f(*args, **kwargs)
runtime.initialize()
return wrapper
def all_gather_test_worker(index, tp_degree, distributed_init_method):
init_distributed_environment(tp_degree,
index,
distributed_init_method,
index,
backend="xla")
ensure_model_parallel_initialized(tp_degree, 1)
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
total_size = 1
for s in tensor_size:
total_size *= s
all_gather_dimension = -1
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="xla").reshape(tensor_size) * (r + 1)
for r in range(tp_degree)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[index % tp_degree]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
torch.testing.assert_close(t, expected)
def all_reduce_test_worker(index, tp_degree, distributed_init_method):
init_distributed_environment(tp_degree,
index,
distributed_init_method,
index,
backend="xla")
ensure_model_parallel_initialized(tp_degree, 1)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1)
for r in range(tp_degree)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[index % tp_degree]
t = tensor_model_parallel_all_reduce(t)
torch.testing.assert_close(t, expected)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
@reinitialize_neuron_runtime
def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size,
test_target):
with patch('torch_xla._XLAC._xla_runtime_is_initialized',
return_value=False):
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size))
monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES",
','.join(['1' for _ in range(tp_size)]))
xmp.spawn(test_target, args=(tp_size, distributed_init_method))


@ -1,83 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import shutil
import tempfile
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from vllm import LLM, SamplingParams
def patch_eagle_draft_with_lm_head(target_model_id: str,
draft_model_id: str) -> str:
# In NxDI, draft model checkpoint must include lm_head weights from target
# model. For more details see https://awsdocs-neuron.readthedocs-hosted.com
# /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html
# #eagle-checkpoint-compatibility
final_draft_dir = "/tmp/patched_eagle_draft"
with tempfile.TemporaryDirectory() as tmp_dir:
target_dir = snapshot_download(repo_id=target_model_id,
local_dir=os.path.join(
tmp_dir, "target"))
draft_dir = snapshot_download(repo_id=draft_model_id,
local_dir=os.path.join(tmp_dir, "draft"))
lm_head_key = "lm_head.weight"
index_path = os.path.join(target_dir, "model.safetensors.index.json")
with open(index_path) as f:
index = json.load(f)
shard_name = index["weight_map"][lm_head_key]
target_safetensor_path = os.path.join(target_dir, shard_name)
with safe_open(target_safetensor_path, framework="pt") as f:
target_lm_head = f.get_tensor(lm_head_key)
draft_path = os.path.join(draft_dir, "pytorch_model.bin")
draft_state_dict = torch.load(draft_path, map_location="cpu")
draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
torch.save(draft_state_dict, draft_path)
shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)
return final_draft_dir
def test_eagle():
patched_draft_path = patch_eagle_draft_with_lm_head(
target_model_id="meta-llama/Llama-2-7b-hf",
draft_model_id="yuhuili/EAGLE-llama2-chat-7B")
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
speculative_config={
"model": patched_draft_path,
"num_speculative_tokens": 5,
"max_model_len": 128
},
max_num_seqs=1,
max_model_len=128,
tensor_parallel_size=2,
override_neuron_config={
"enable_eagle_speculation": True,
"enable_fused_speculation": True,
"fused_qkv": True
},
)
prompts = [
"The president of the United States is",
]
outputs = llm.generate(prompts, SamplingParams(top_k=1))
expected_output = " the head of state and head of government of " \
"the United States. The president direct"
for output in outputs:
generated_text = output.outputs[0].text
print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
assert (expected_output == generated_text)
print("Neuron Eagle speculation test passed.")


@ -1,64 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
def test_mistral():
llm = LLM(model="mistralai/Mistral-7B-v0.1",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=128,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True
})
# Send more prompts than the compiled batch size (4) and request
# varying generation lengths to test accuracy related to
# Neuron-specific sequence id sorting.
prompts = [
"The president of the United States is",
"The capital of France is",
"What is Annapurna labs?",
"I believe the meaning of life is",
"Tell me a story about a brave knight",
"Hello, my name is Llama",
]
sampling_params = [
SamplingParams(top_k=1, max_tokens=10),
SamplingParams(top_k=1, max_tokens=20),
SamplingParams(top_k=1, max_tokens=30),
SamplingParams(top_k=1, max_tokens=40),
SamplingParams(top_k=1, max_tokens=50),
SamplingParams(top_k=1, max_tokens=60)
]
outputs = llm.generate(prompts, sampling_params)
expected_outputs = [
" the most powerful person in the world. He is",
" a city of many faces. It is a city of history, culture, art, "
"fashion, and",
"\n\nAnnapurna Labs is a semiconductor company that was founded "
"in 2013 by Amazon. The company is",
" to be happy.\n\nI believe that happiness is a choice.\n\nI "
"believe that happiness is a state of mind.\n\nI believe that "
"happiness is a journey.\n\nI believe",
" who rescued a princess from a dragon.\n\nTell me a story about"
" a princess who rescued herself from a dragon.\n\nTell me a "
"story about a princess who rescued herself from a dragon and "
"then rescued a knight from",
" and I am a 10 year old male. I am a very friendly and "
"affectionate boy who loves to be around people. I am a very "
"active boy who loves to play and run around. I am a very smart "
"boy who loves to learn new things. I am a very loyal boy"
]
for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
assert (expected_output == generated_text)
print("Neuron Mistral test passed.")

View File

@ -1,97 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
def test_llama_single_lora():
sql_lora_files = snapshot_download(
repo_id="yard1/llama-2-7b-sql-lora-test")
llm = LLM(model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True,
"lora_modules": [{
"name": "lora_id_1",
"path": sql_lora_files
}]
},
enable_lora=True,
max_loras=1,
max_lora_rank=256,
device="neuron")
"""For multi-lora requests using NxDI as the backend, only the lora_name
needs to be specified. The lora_id and lora_path are supplied at the LLM
    class/server initialization, after which the paths are handled by NxDI."""
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
prompts = [
"The president of the United States is",
"The capital of France is",
]
outputs = llm.generate(prompts,
SamplingParams(top_k=1),
lora_request=[lora_req_1, lora_req_1])
expected_outputs = [
" the head of state and head of government of the United States. "
"The president direct",
" a city of contrasts. The city is home to the Eiffel Tower"
]
for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
assert (expected_output == generated_text)
def test_llama_multiple_lora():
sql_lora_files = snapshot_download(
repo_id="yard1/llama-2-7b-sql-lora-test")
llm = LLM(model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
override_neuron_config={
"sequence_parallel_enabled":
False,
"skip_warmup":
True,
"lora_modules": [{
"name": "lora_id_1",
"path": sql_lora_files
}, {
"name": "lora_id_2",
"path": sql_lora_files
}]
},
enable_lora=True,
max_loras=2,
max_lora_rank=256,
device="neuron")
"""For multi-lora requests using NxDI as the backend, only the lora_name
needs to be specified. The lora_id and lora_path are supplied at the LLM
    class/server initialization, after which the paths are handled by NxDI."""
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
lora_req_2 = LoRARequest("lora_id_2", 1, " ")
prompts = [
"The president of the United States is",
"The capital of France is",
]
outputs = llm.generate(prompts,
SamplingParams(top_k=1),
lora_request=[lora_req_1, lora_req_2])
expected_outputs = [
" the head of state and head of government of the United States. "
"The president direct",
" a city of contrasts. The city is home to the Eiffel Tower"
]
for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
assert (expected_output == generated_text)

View File

@ -1,903 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim
from vllm.utils import cdiv
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
"""
Load block tables from HBM into SRAM
`block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
    If `num_tiles > B_P_SIZE`, we need to further tile the `num_tiles` dimension.
"""
B_P_SIZE = 128
# reshape as `(num_tiles, num_blocks_per_tile)`
assert len(block_tables_hbm.shape) == 1
(num_total_blocks, ) = block_tables_hbm.shape
assert num_blocks_per_tile * num_tiles == num_total_blocks
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_sbuf = nl.zeros(
(cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
dtype=nl.int32,
)
for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(num_blocks_per_tile)[None, :]
block_tables_sbuf[i, i_p, i_f] = nl.load(
block_tables_hbm[i_p + i * B_P_SIZE, i_f],
dtype=nl.int32,
mask=(i_p + i * B_P_SIZE < num_tiles),
)
return block_tables_sbuf
@nki.jit
def transform_block_tables_for_indirect_load(
block_tables,
block_size_tiling_factor,
num_head,
head_id,
):
"""
This function does two things:
1. calculate new `block_tables` for a `head_id` after flattening
`num_block`, `num_head`, and `block_size_tiling_factor` dimensions
2. transpose the result so that `block_table` for each tile is mapped to
SBUF Partition dimension for vectorized DMA
Tiling trick to further improve DMA performance:
Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
blocks of a given `head_id` from HBM, the load `cache[block_tables,
head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
fully utilize hardware parallelization. The solution is to tile `block_size`
into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
`(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.
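    Illustrative example (numbers chosen for exposition, not taken from the
    code): with block_size = 32 and M = 16 blocks per large tile,
    block_size_tiling_factor = 128 / 16 = 8 and tiled_block_size = 32 / 8 = 4,
    so the per-head load becomes (128, 4, D) and all 128 SBUF partitions are
    used.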
    Note:
    We don't further tile the D dimension, as a small DMA size also hurts
    performance.
"""
B_P_SIZE = 128
num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
block_tables.shape)
assert num_tiles_per_partition == B_P_SIZE
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
block_tables_transposed = nl.ndarray(
(
num_loads,
par_dim(B_P_SIZE),
num_partitions * num_tiles_per_partition,
),
dtype=nl.int32,
)
# prepare iota ahead of time to avoid repeatedly using Gpsimd
if num_head > 1:
head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
head_id = nl.transpose(
head_id.broadcast_to((1, num_tiles_per_partition)))
if num_blocks_per_tile > 1:
head_id = head_id.broadcast_to(
(num_tiles_per_partition, num_blocks_per_tile))
if block_size_tiling_factor > 1:
broadcast_shape = (
num_tiles_per_partition,
num_blocks_per_tile,
block_size_tiling_factor,
)
offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
dtype=nl.int32).broadcast_to(broadcast_shape)
for partition_id in nl.affine_range(num_partitions):
block_tables_partition = block_tables[partition_id]
if num_head > 1:
# fuse num_block and num_head dimension
block_tables_partition = block_tables_partition * num_head + head_id
if block_size_tiling_factor > 1:
# need to apply block size tiling trick
assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
block_tables_partition = ((block_tables_partition *
block_size_tiling_factor).reshape(
(num_tiles_per_partition,
num_blocks_per_tile,
1)).broadcast_to(broadcast_shape))
new_block_tables = block_tables_partition + offset
new_block_tables = new_block_tables.reshape(
(num_tiles_per_partition, B_P_SIZE))
else:
new_block_tables = block_tables_partition
# transpose the block table so that it can be used by vector DGE
for i in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = (partition_id * num_tiles_per_partition +
nl.arange(num_tiles_per_partition)[None, :])
block_tables_transposed[i, i_p, i_f] = nl.transpose(
new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
return block_tables_transposed
@nki.jit
def load_kv_tile_from_cache(
cur_k_tile,
cur_v_tile,
kv_cache,
block_tables,
large_k_tile_idx,
num_blocks_per_large_tile,
tiled_block_size,
B_P_SIZE,
B_D_SIZE,
):
"""
Load KV cache and transform Key and Value into layout required by Matmul
Vectorized DMA Load layout:
Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
Layout used by attention matmuls:
Key: (par_dim(B_D_SIZE), seqlen_kv)
Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
"""
# load key cache
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_k_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
# Transpose SBUF tensor using PE
for tb_i in nl.affine_range(tiled_block_size):
cur_k_tile[
:,
nl.ds(
load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
B_P_SIZE,
),
] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])
# load value cache
    for load_idx in nl.affine_range(num_loads):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
cur_v_tile[
:,
nl.ds(
load_idx * tiled_block_size * B_D_SIZE,
tiled_block_size * B_D_SIZE,
),
] = loaded
@nki.jit
def transpose_p_local(p_local_transposed,
p_local,
LARGE_TILE_SZ,
B_F_SIZE=512):
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.sbuf,
dtype=p_local.dtype)
else:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.psum,
dtype=np.float32)
for j in nl.affine_range(B_F_SIZE // 128):
j_128_slice = nl.ds(j * 128, 128)
i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice])
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice])
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
p_local_t_tmp, dtype=p_local_transposed.dtype)
@nki.jit
def _flash_attention_core(
q_local_tile,
k,
v,
o_buffer,
l_buffer,
m_buffer,
kernel_dtype,
acc_type,
tile_mask,
use_causal_mask,
q_tile_idx=None,
initialize=False,
LARGE_TILE_SZ=2048,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
qk_res_buffer=None,
):
"""
    The flash attention core function calculates self-attention between a tile
    of q and a block of K and V.
    The q_local_tile has shape (B_P_SIZE, B_D_SIZE).
    The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension is
    split into tiles of size B_F_SIZE.
    The results are stored in the following three buffers:
o_buffer: (B_P_SIZE, d)
l_buffer: (B_P_SIZE, 1)
m_buffer: (B_P_SIZE, 1)
All IO buffers are in SBUF.
"""
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
dtype=acc_type)
for k_i in nl.affine_range(num_k_tile_per_large_tile):
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
if use_causal_mask:
            # masks are used to apply computation only to the lower half of
            # the matrix, which reduces the arithmetic intensity by up to 50%
multiplication_required_selection = (q_tile_idx * B_P_SIZE
>= k_i * B_F_SIZE)
else:
multiplication_required_selection = True
if multiplication_required_selection:
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True) # (p(128), 512)
qk_res_buf[:, k_i_b_f_slice] = nl.where(
tile_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
else:
qk_res_buf[:, k_i_b_f_slice] = -9984.0
# Calculate max of the current tile
max_local[:, k_i] = nisa.tensor_reduce(
np.max,
qk_res_buf[:, k_i_b_f_slice],
axis=(1, ),
dtype=acc_type,
negate=False,
)
if qk_res_buffer is not None:
qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])
max_ = nisa.tensor_reduce(
np.max,
max_local[:, :],
axis=(1, ),
dtype=acc_type,
negate=False,
)
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
dtype=o_buffer.dtype)
if initialize:
m_buffer[:, 0] = nl.copy(max_)
m_current = max_
else:
m_previous = nl.copy(m_buffer[:, 0])
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
m_current = m_buffer[:, 0]
# Compute scaling factor
alpha = nisa.activation(
np.exp,
m_previous,
bias=-1 * m_current,
scale=1.0,
)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
p_partial_sum = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
dtype=acc_type,
)
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
        # compute exp(qk - max)
        # Compute the partial row-tile sum of exp(qk - max)
        # FIXME: Use activation accumulate to accumulate over the k_r_i loop?
p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
np.exp,
qk_res_buf[:, k_r_i_reduce_slice],
bias=-1 * m_current,
scale=1.0,
reduce_op=nl.add,
reduce_res=p_partial_sum[:, k_r_i],
dtype=kernel_dtype,
)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
transpose_p_local(
p_local_transposed=p_local_transposed,
p_local=p_local,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_F_SIZE=B_F_SIZE,
)
pv_psum = nl.zeros(
(par_dim(B_P_SIZE), B_D_SIZE),
dtype=np.float32,
buffer=nl.psum,
)
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
pv_psum[:, :] += nl.matmul(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
transpose_x=True,
) # (128, 128) (p(Br), d)
if initialize:
o_buffer[:, :] = nl.copy(pv_psum[:, :])
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
else:
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
l_prev = l_buffer[:, 0]
l_exp = nl.add(
nl.exp(nl.subtract(l_prev, m_current)),
ps,
)
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
B_P_SIZE = 128
B_D_SIZE = v_hbm_tile.shape[-1]
loaded = nl.load(v_hbm_tile[
nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
:,
])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
@nki.jit
def flash_paged_attention(
query,
key,
value,
kv_cache,
block_tables,
mask,
softmax_scale=None,
mixed_precision=True,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
):
"""
Flash PagedAttention Forward Kernel.
IO tensor layouts:
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
- This kernel requires seq_k == seq_v
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
dimension.
- We use paged cache blocks (kv_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
    - If mixed_precision is True, then all Tensor Engine operations will be
    performed in bfloat16 and accumulation will be performed in float32.
    Otherwise the intermediates will be in the same dtype as the inputs.
    Compile-time Constants:
    - softmax_scale: scaling for softmax; if None, defaults to `1.0/(d**0.5)`
    - mixed_precision: flag to run non-matmul ops in fp32 precision; defaults
    to `True`. If False, the same precision as the input dtypes is used.
    - LARGE_TILE_SZ: `default=2048`, size of the KV tile used for the
    attention reduction
    GQA support Notes:
    The SPMD grid used to launch the kernel should be over kv_heads instead of
    nheads.
Example usage:
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
usage: `flash_fwd[b, h](q, k, v, ...)`
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
"""
B_F_SIZE = 512
B_P_SIZE = 128
b, h, d, seqlen_q = query.shape
B_D_SIZE = d
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
_, num_blocks, k_h, block_size, _ = kv_cache.shape
q_h_per_k_h = h // k_h
assert b == 1, f"invalid batch size {b=}"
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
cache_shape = (2, num_blocks, k_h, block_size, d)
assert (tuple(kv_cache.shape) == cache_shape
), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
assert key is None or tuple(key.shape) == (
1,
k_h,
d,
seqlen_q,
), f"key shape {key.shape} mismatch!"
assert value is None or tuple(value.shape) == (
1,
k_h,
seqlen_q,
d,
), f"value shape {value.shape} mismatch!"
assert (
nl.program_ndim() == 2
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
batch_id = nl.program_id(axis=0)
head_id = nl.program_id(axis=1)
(num_active_blocks, ) = block_tables.shape
context_kv_len = num_active_blocks * block_size
assert (
LARGE_TILE_SZ % B_F_SIZE == 0
), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
assert (context_kv_len % LARGE_TILE_SZ == 0
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert is_power_of_2(
num_blocks_per_large_tile
), f"{num_blocks_per_large_tile=} is expected of be power of 2"
if seqlen_q > B_F_SIZE:
MAX_REDUCTION_TILE = 2048
if seqlen_q // 2 > MAX_REDUCTION_TILE:
assert (
seqlen_q % MAX_REDUCTION_TILE == 0
), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
else:
assert (seqlen_q % B_F_SIZE == 0
), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
softmax_scale = softmax_scale or (1.0 / (d**0.5))
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
o = nl.ndarray((b, h, seqlen_q, d),
dtype=query.dtype,
buffer=nl.shared_hbm)
hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
None,
None,
None,
None,
)
if return_debug_tensors:
hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
qk_res_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
block_tables_sbuf = load_block_tables(
block_tables_hbm=block_tables,
num_tiles=num_large_k_tile,
num_blocks_per_tile=num_blocks_per_large_tile,
)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
if num_blocks_per_large_tile < B_P_SIZE:
# we checked num_blocks_per_tile is a power of 2
assert B_P_SIZE % num_blocks_per_large_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
# We assume block_size >= block_size_tiling_factor
assert block_size % block_size_tiling_factor == 0
else:
block_size_tiling_factor = 1
tiled_block_size = block_size // block_size_tiling_factor
# Indirect DMA load must be placed along Partition Dimension
block_tables_sbuf = transform_block_tables_for_indirect_load(
block_tables_sbuf,
block_size_tiling_factor=block_size_tiling_factor,
num_head=k_h,
head_id=head_id,
)
# Flatten KV cache to be 3D for loading into SBUF
new_cache_shape = (
2,
num_blocks * k_h * block_size_tiling_factor,
tiled_block_size * d,
)
kv_cache = kv_cache.reshape(new_cache_shape)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
l_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
m_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
cur_k_tile = nl.ndarray(
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype,
)
cur_v_tile = nl.ndarray(
(par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
dtype=kernel_dtype,
)
load_kv_tile_from_cache(
cur_k_tile=cur_k_tile,
cur_v_tile=cur_v_tile,
kv_cache=kv_cache,
block_tables=block_tables_sbuf,
large_k_tile_idx=large_k_tile_idx,
num_blocks_per_large_tile=num_blocks_per_large_tile,
tiled_block_size=tiled_block_size,
B_P_SIZE=B_P_SIZE,
B_D_SIZE=B_D_SIZE,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
kernel_dtype=kernel_dtype,
acc_type=acc_type,
tile_mask=cur_mask,
use_causal_mask=False,
q_tile_idx=i,
initialize=large_k_tile_idx == 0,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
)
# compute attention between input query, key and value
if key is not None and value is not None:
B_F_SIZE = min(seqlen_q, B_F_SIZE)
LARGE_TILE_SZ = seqlen_q
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
cur_v_tile = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
dtype=kernel_dtype,
)
loaded = nl.load(key[batch_id, head_id, :, :])
if loaded.dtype != kernel_dtype:
loaded = nl.copy(loaded, dtype=kernel_dtype)
cur_k_tile[:, :] = loaded
v_hbm_tile = value[batch_id, head_id]
for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
load_v_tile(
v_hbm_tile=v_hbm_tile,
cur_v_tile=cur_v_tile,
large_tile_idx=0,
v_i=v_i,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
kernel_dtype=kernel_dtype,
acc_type=acc_type,
tile_mask=cur_mask,
use_causal_mask=True,
q_tile_idx=i,
initialize=False,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
qk_res_buffer=(qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
out = nl.multiply(
o_buffer[i, i_q_h],
nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
dtype=kernel_dtype,
)
nl.store(
o[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
:,
],
out,
)
# maximum and summation statistics
if return_debug_tensors:
nl.store(
hbm_m_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
m_buffer[i, i_q_h, :, :],
)
nl.store(
hbm_l_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
l_buffer[i, i_q_h],
)
nl.store(
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
qk_res_buffer[batch_id, i_q_h, :, :],
)
if return_debug_tensors:
return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
return o
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
"""
Reorder the mask to make it compatible with the flash attention kernel.
    We vectorize the KV cache read to improve DMA utilization. However, the
    layout that maximizes DMA bandwidth changes the order in which tokens are
    consumed. The token layout (inner 2 dimensions) after the vectorized load
    is (B_P_SIZE, tiled_block_size) within a tile of `B_P_SIZE *
    tiled_block_size` tokens, and at each step the engine consumes a column
    (rather than a row) of B_P_SIZE tokens, so the tokens are visited in a
    strided order. To make sure the mask matches the order in which tokens are
    consumed, we need to transpose it accordingly.
"""
total_query_len, total_seq_len = mask.shape
context_kv_len = total_seq_len - total_query_len
B_P_SIZE = 128
assert (LARGE_TILE_SZ
>= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
if tiled_block_size > 1:
# Mask reordering is needed when tiled_block_size > 1
device = mask.device
mask = mask.cpu()
context_mask = mask[:, :context_kv_len]
context_mask = context_mask.view(
total_query_len,
context_kv_len // LARGE_TILE_SZ,
num_tiled_blocks // B_P_SIZE,
B_P_SIZE,
tiled_block_size,
)
context_mask = context_mask.transpose(3, 4).reshape(
total_query_len, context_kv_len)
new_mask = mask[:, context_kv_len:]
return torch.concat([context_mask, new_mask], dim=1).to(device)
else:
return mask
def flash_attn_varlen_nkifunc(
query,
key,
value,
kv_cache,
block_table,
attn_mask,
n_kv_head=None,
head_size=None,
LARGE_TILE_SZ=2048,
mixed_precision=True,
):
"""
Compute flash paged attention for variable length sequences.
This function is a wrapper around the flash attention NKI kernel. It takes
in the following arguments:
- query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d)
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
Notes:
- attn_mask must be reordered outside using `reorder_context_mask`
- Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
for better DMA throughput
"""
if n_kv_head is None:
n_kv_head = kv_cache.shape[2]
assert kv_cache.shape[0] == 2
assert kv_cache.shape[2] == n_kv_head
if head_size is None:
head_size = kv_cache.shape[-1]
kwargs = dict(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
block_tables=block_table,
mask=attn_mask,
softmax_scale=1.0 / (head_size**0.5),
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
o = flash_paged_attention[1, n_kv_head](**kwargs)
return o
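# Illustrative shape bookkeeping for the wrapper above (made-up sizes chosen to
# satisfy the asserts in flash_paged_attention; the kernel itself only runs on
# Neuron hardware):
#   query:       (1, 8, 128, 128)      # (1, n_heads, d, seq_q)
#   kv_cache:    (2, 512, 1, 32, 128)  # (2, n_blocks, n_kv_heads, block_size, d)
#   block_table: (64,)                 # 64 blocks * 32 tokens = 2048 context tokens
#   attn_mask:   (128, 64 * 32 + 128)  # (seq_q, n_active_blocks * block_size + seq_q)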
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""
Writes key-value pairs to the KV cache at specified positions.
Args:
key (torch.Tensor): Key tensor with shape
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
kv_cache (torch.Tensor): Key/value cache tensor with shape
(2, num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)
Returns:
None: Updates the kv_cache tensor in-place
"""
block_size = kv_cache.size(3)
n_kv_head = key.size(1)
# Calculate indices with explicit floor division
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size
# Create the head indices tensor
head_indices = torch.arange(n_kv_head, device=key.device)
# Update caches using index_put_
kv_cache.index_put_(
(torch.tensor([0], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), key)
kv_cache.index_put_(
(torch.tensor([1], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), value)

View File

@ -54,7 +54,6 @@ SystemEnv = namedtuple(
'is_xnnpack_available',
'cpu_info',
'rocm_version', # vllm specific field
'neuron_sdk_version', # vllm specific field
'vllm_version', # vllm specific field
'vllm_build_flags', # vllm specific field
'gpu_topo', # vllm specific field
@ -275,15 +274,6 @@ def get_rocm_version(run_lambda):
r'HIP version: (\S+)')
def get_neuron_sdk_version(run_lambda):
# Adapted from your install script
try:
result = run_lambda(["neuron-ls"])
return result if result[0] == 0 else 'N/A'
except Exception:
return 'N/A'
def get_vllm_version():
from vllm import __version__, __version_tuple__
@ -306,10 +296,9 @@ def get_vllm_version():
def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format(
return 'CUDA Archs: {}; ROCm: {}'.format(
os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'),
'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled',
'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled',
)
@ -601,7 +590,6 @@ def get_env_info():
conda_packages = get_conda_packages(run_lambda)
rocm_version = get_rocm_version(run_lambda)
neuron_sdk_version = get_neuron_sdk_version(run_lambda)
vllm_version = get_vllm_version()
vllm_build_flags = summarize_vllm_build_flags()
gpu_topo = get_gpu_topo(run_lambda)
@ -635,7 +623,6 @@ def get_env_info():
is_xnnpack_available=is_xnnpack_available(),
cpu_info=get_cpu_info(run_lambda),
rocm_version=rocm_version,
neuron_sdk_version=neuron_sdk_version,
vllm_version=vllm_version,
vllm_build_flags=vllm_build_flags,
gpu_topo=gpu_topo,
@ -702,7 +689,6 @@ env_info_fmt += """
vLLM Info
==============================
ROCM Version : {rocm_version}
Neuron SDK Version : {neuron_sdk_version}
vLLM Version : {vllm_version}
vLLM Build Flags:
{vllm_build_flags}

View File

@ -461,11 +461,6 @@ class ModelConfig:
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
configure the neuron config that can not be gathered from the vllm
arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`."""
pooler_config: Optional["PoolerConfig"] = field(init=False)
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
@ -785,10 +780,6 @@ class ModelConfig:
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
if (not current_platform.is_neuron() and self.override_neuron_config):
raise ValueError(
"`override_neuron_config` is only supported on Neuron.")
# Avoid running try_verify_and_update_config multiple times
self.config_updated = False
@ -1696,13 +1687,7 @@ class ModelConfig:
"""
For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
True to enable cross-attention
Neuron needs all multimodal data to be in the decoder and does not
need to explicitly enable cross-attention
"""
if (current_platform.is_neuron()
and self.hf_config.model_type == "mllama"):
return False
return is_encoder_decoder(self.hf_config)
@property
@ -1871,7 +1856,7 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"]
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"]
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config
@ -1927,9 +1912,7 @@ class DeviceConfig:
self.device_type = self.device.type
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
if self.device_type in ["tpu"]:
self.device = None
else:
# Set device with device type
@ -3941,7 +3924,6 @@ class VllmConfig:
f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
f"tokenizer_mode={self.model_config.tokenizer_mode}, "
f"revision={self.model_config.revision}, "
f"override_neuron_config={self.model_config.override_neuron_config}, " # noqa
f"tokenizer_revision={self.model_config.tokenizer_revision}, "
f"trust_remote_code={self.model_config.trust_remote_code}, "
f"dtype={self.model_config.dtype}, "

View File

@ -33,9 +33,8 @@ class CacheConfig:
"""Configuration for the KV cache."""
block_size: SkipValidation[BlockSize] = None # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
only block sizes up to 32 are supported.
This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_config()` based on the current

View File

@ -377,10 +377,7 @@ class ParallelConfig:
from vllm.executor import ray_utils
backend: DistributedExecutorBackend = "mp"
ray_found = ray_utils.ray_is_available()
if current_platform.is_neuron():
# neuron uses single process to control multiple devices
backend = "uni"
elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):

View File

@ -1,20 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.platforms import current_platform
if current_platform.is_neuron():
import torch_xla.core.xla_model as xm
class NeuronCommunicator(DeviceCommunicatorBase):
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "Neuron only supports dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)

View File

@ -419,8 +419,6 @@ class EngineArgs:
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: dict[str, Any] = \
get_field(ModelConfig, "override_neuron_config")
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
ModelConfig.override_pooler_config
compilation_config: CompilationConfig = \
@ -561,8 +559,6 @@ class EngineArgs:
help=model_kwargs["hf_token"]["help"])
model_group.add_argument("--hf-overrides",
**model_kwargs["hf_overrides"])
model_group.add_argument("--override-neuron-config",
**model_kwargs["override_neuron_config"])
model_group.add_argument("--override-pooler-config",
**model_kwargs["override_pooler_config"])
model_group.add_argument("--logits-processor-pattern",
@ -992,7 +988,6 @@ class EngineArgs:
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config,

View File

@ -236,7 +236,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu]
# rocm, cpu]
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(),

View File

@ -73,11 +73,6 @@ class CustomOp(nn.Module):
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
def forward_neuron(self, *args, **kwargs):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
@ -105,8 +100,6 @@ class CustomOp(nn.Module):
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_neuron():
return self.forward_neuron
elif current_platform.is_out_of_tree():
return self.forward_oot
else:

View File

@ -95,13 +95,6 @@ class SiluAndMul(CustomOp):
self.op(out, x)
return out
def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
result = s * x_reshaped[:, d:]
return result.view(*x.shape[:-1], d)
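    # For reference, the PyTorch-native fallback that now serves all platforms
    # computes the same thing; a minimal sketch (not the exact vLLM code):
    #   d = x.shape[-1] // 2
    #   return torch.nn.functional.silu(x[..., :d]) * x[..., d:]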
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):

View File

@ -26,7 +26,6 @@ QuantizationMethods = Literal[
"bitsandbytes",
"hqq",
"experts_int8",
"neuron_quant",
"ipex",
"quark",
"moe_wna16",
@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"ptpc_fp8": PTPCFp8Config,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,

View File

@ -1,76 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from importlib.util import find_spec
from typing import Any, Optional
from torch.nn import Module
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
class AlwaysSupportedDtypes(list):
def __contains__(self, item):
return True
class NeuronQuantConfig(QuantizationConfig):
"""Int8 Quantization Config class for Neuron Backend."""
def __init__(
self,
dequant_dtype: str = "f16",
quantize_method: str = "vector_dynamic",
) -> None:
super().__init__()
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
raise ValueError(
f"Neuron quantization datatype {self.quant_dtype} is not valid,"
f" the quantization datatype should match one of the below "
f"types {SUPPORTED_QUANT_DTYPE_LIST}")
self.dequant_dtype = dequant_dtype
self.quantize_method = quantize_method
def get_name(self) -> QuantizationMethods:
return "neuron_quant"
def get_supported_act_dtypes(self) -> list[str]:
# Neuron implements custom handling logic for quantization support
return AlwaysSupportedDtypes()
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"This function should not be called with Neuron Backend")
@staticmethod
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
quantize_method = cls.get_from_keys(config, ["quantize_method"])
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
return cls(dequant_dtype=dequant_dtype,
quantize_method=quantize_method)
def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
if find_spec("transformers_neuronx") is not None:
return self.get_quantization_config()
else:
raise NotImplementedError(
"Neuron Quantization is only supported through"
" transformers_neuronx.")
def get_quantization_config(self):
from transformers_neuronx.config import QuantizationConfig
return QuantizationConfig(quant_dtype=self.quant_dtype,
dequant_dtype=self.dequant_dtype,
quantize_method=self.quantize_method)

View File

@ -7,7 +7,7 @@ import torch
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
from .common import apply_rotary_emb_torch
@CustomOp.register("rotary_embedding")
@ -149,87 +149,6 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style)
return query, key
def forward_neuron(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
def _apply_rotary_emb_neuron(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
# x1 = x[..., ::2]
# x2 = x[..., 1::2]
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
if offsets is not None:
positions = positions + offsets
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
if self.rotary_dim == self.head_size:
query = apply_rotary_emb_dispatch(query, cos, sin,
self.is_neox_style)
query = query.reshape(query_shape)
if key is not None:
key = apply_rotary_emb_dispatch(key, cos, sin,
self.is_neox_style)
key = key.reshape(key_shape)
else:
head_size = query.shape[-1]
query_reshaped = query.view(-1, head_size)
query_pass = query_reshaped[:, self.rotary_dim:].view(
*query.shape[:-1], head_size - self.rotary_dim)
query_rot = query_reshaped[:, :self.rotary_dim].view(
*query.shape[:-1], self.rotary_dim)
query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass),
dim=-1).reshape(query_shape)
if key is not None:
key_reshaped = key.view(-1, head_size)
key_pass = key_reshaped[:, self.rotary_dim:].view(
*key.shape[:-1], head_size - self.rotary_dim)
key_rot = key_reshaped[:, :self.rotary_dim].view(
*key.shape[:-1], self.rotary_dim)
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
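    # Toy check of the neox-style rotation above (illustrative numbers,
    # rotary_dim = head_size = 4):
    #   x = [1, 2, 3, 4], cos = [0, 1], sin = [1, 0]
    #   x1, x2 = [1, 2], [3, 4]
    #   cat(x1*cos - x2*sin, x2*cos + x1*sin) = [-3, 2, 1, 4]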
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"

View File

@ -1,476 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in transformers-neuronx
framework."""
import ast
import copy
import importlib
import os
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logprobs import Logprob
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
"half": "f16",
"float16": "f16",
"bfloat16": "bf16",
"float": "f32",
"float32": "f32",
torch.float16: "f16",
torch.bfloat16: "bf16",
torch.float32: "f32",
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
"MistralForSampling", "MistralForCausalLM")
}
class NeuronCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
on_device_sampling_disabled: bool = False) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
) -> torch.Tensor:
logits = self.model(input_ids,
cache_ids=positions,
start_ids=input_block_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
if self.on_device_sampling_disabled:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
# On-device sampling outputs the token ids directly.
sampled_token_ids = logits.flatten()
next_tokens = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
samples = []
for seq_id in seq_group.seq_ids:
token_id = sampled_token_ids[sample_idx].item()
samples.append(
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
next_tokens.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))
return SamplerOutput(outputs=next_tokens)
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
**kwargs)
self.model.to_neuron()
class NeuronSpeculationCausalLM(nn.Module):
"""A Neuron-optimized causal language model with speculative decoding."""
SPECULATION_TERMINATION_ID = -1
def __init__(self, speculation_model) -> None:
super().__init__()
self.model = speculation_model
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
) -> torch.Tensor:
tokens, counts = self.model.speculative_iteration(
input_ids, positions, input_block_ids)
# Mark the end of accepted speculative tokens for each sequence with the
# speculation termination id.
batch_size, steps = tokens.shape
mask = torch.arange(steps).expand(batch_size, -1) >= counts
tokens[mask] = self.SPECULATION_TERMINATION_ID
return tokens
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[list[SamplerOutput]]:
batch_size, num_steps = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
sampler_output_list = []
for step_index in range(num_steps):
if all(token_id == self.SPECULATION_TERMINATION_ID
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for sequence_index in range(batch_size):
token_id = accepted_token_ids_by_step[step_index][
sequence_index]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
return sampler_output_list
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
return arch
raise ValueError(
f"Model architectures {architectures} are not supported on Neuron "
f"for now. Supported architectures: "
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_buckets(env: str, default_value: list[int]) -> list[int]:
env_value = os.getenv(env)
if env_value is None:
return default_value
buckets_remove_empty = filter(
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
buckets_int = map(int, buckets_remove_empty)
buckets_list = list(buckets_int)
return buckets_list
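    # Example (illustrative): NEURON_TOKEN_GEN_BUCKETS="128,512,2048" yields
    # [128, 512, 2048]; if the variable is unset, `default_value` is returned.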
def _get_default_neuron_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
"""Generate a neuron config based on vllm config args."""
from transformers_neuronx.config import ContinuousBatchingConfig
from transformers_neuronx.constants import LAYOUT_BSH
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
quant_config = dict(
dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
quantize_method="vector_dynamic")
neuron_quantization_config_builder = lambda quant: get_quantization_config(
quant).from_config(quant_config).get_quant_method(None, "")
# TODO: Add Paged attention config to the default neuron arguments.
default_neuron_args = dict(
collectives_layout=LAYOUT_BSH,
attention_layout=LAYOUT_BSH,
fuse_qkv=True,
quant=neuron_quantization_config_builder(model_config.quantization)
if model_config.quantization else None,
continuous_batching=continuous_batching_config,
weight_tiling=bool(model_config.quantization),
on_device_generation=_get_neuron_on_device_generation_config(
model_config))
return default_neuron_args
def _get_default_neuron_config_for_speculation(
model_config: ModelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
"""Generate a neuron config for speculative decoding based on
vllm config args."""
from transformers_neuronx.config import ContinuousBatchingConfig
from transformers_neuronx.constants import LAYOUT_BSH
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
default_neuron_args = dict(collectives_layout=LAYOUT_BSH,
attention_layout=LAYOUT_BSH,
fuse_qkv=True,
on_device_embedding=True,
continuous_batching=continuous_batching_config,
on_device_generation=copy.deepcopy(
model_config.neuron_sampling_params))
return default_neuron_args
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
if not _is_neuron_on_device_sampling_disabled(model_config):
return copy.deepcopy(model_config.neuron_sampling_params)
return None
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
return not getattr(model_config, "neuron_sampling_params", None)
def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
from transformers_neuronx.config import (ContinuousBatchingConfig,
GenerationConfig,
KVCacheQuantizationConfig,
NeuronConfig, QuantizationConfig,
SparseAttnConfig)
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
**sparse_attn)
kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
if kv_cache_quant:
overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
**kv_cache_quant)
continuous_batching = overridden_neuron_config.pop("continuous_batching",
{})
if continuous_batching:
overridden_neuron_config[
"continuous_batching"] = ContinuousBatchingConfig(
**continuous_batching)
quant = overridden_neuron_config.pop("quant", {})
if quant:
overridden_neuron_config["quant"] = QuantizationConfig(**quant)
on_device_generation = overridden_neuron_config.pop(
"on_device_generation", {})
if on_device_generation:
overridden_neuron_config["on_device_generation"] = GenerationConfig(
**on_device_generation)
default_neuron_config.update(overridden_neuron_config)
return NeuronConfig(**default_neuron_config)
def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
"""Initializes a neuron-optimized model for inference."""
# Create a model instance.
model = NeuronCausalLM(
model_config.hf_config,
_is_neuron_on_device_sampling_disabled(model_config))
default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
model.load_weights(model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized speculation model for inference.
This method is only applicable for speculation with a standalone draft model
"""
from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder
# For Eagle SD, we need to pass in additional parameters in neuron config.
is_eagle = getattr(speculation_config.draft_model_config.hf_config,
"is_eagle", False)
# Create target model instance.
target_model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config_for_speculation(
model_config, parallel_config, scheduler_config)
if is_eagle:
default_neuron_config_args['is_eagle_target'] = True
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
target_model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
target_model.eval()
# Create draft model instance.
draft_model = NeuronCausalLM(
speculation_config.draft_model_config.hf_config)
default_draft_neuron_config_args = (
_get_default_neuron_config_for_speculation(
speculation_config.draft_model_config, parallel_config,
scheduler_config))
if is_eagle:
default_draft_neuron_config_args['is_eagle_draft'] = True
default_draft_neuron_config_args['has_pre_attention_norm'] = False
draft_neuron_config = _get_neuron_config_after_override(
default_draft_neuron_config_args,
speculation_config.draft_model_config.override_neuron_config)
draft_model.load_weights(speculation_config.draft_model_config.model,
tp_degree=speculation_config.
draft_parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[
speculation_config.draft_model_config.dtype],
neuron_config=draft_neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
draft_model.eval()
num_speculative_tokens = speculation_config.num_speculative_tokens
# Create speculation model instance.
speculation_model = FusedSpeculativeDecoder(draft_model.model,
target_model.model,
num_speculative_tokens)
speculation_model.to_neuron()
return NeuronSpeculationCausalLM(speculation_model)
def get_neuron_eagle_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized EAGLE speculation model for inference."""
from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder
# Create target model instance.
target_model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config_for_speculation(
model_config, parallel_config, scheduler_config)
default_neuron_config_args['is_eagle_target'] = True
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
target_model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
target_model.eval()
# Create draft model instance.
draft_model = NeuronCausalLM(
speculation_config.draft_model_config.hf_config)
default_draft_neuron_config_args = (
_get_default_neuron_config_for_speculation(
speculation_config.draft_model_config, parallel_config,
scheduler_config))
default_draft_neuron_config_args['is_eagle_draft'] = True
default_draft_neuron_config_args['has_pre_attention_norm'] = False
draft_neuron_config = _get_neuron_config_after_override(
default_draft_neuron_config_args,
speculation_config.draft_model_config.override_neuron_config)
draft_model.load_weights(speculation_config.draft_model_config.model,
tp_degree=speculation_config.
draft_parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[
speculation_config.draft_model_config.dtype],
neuron_config=draft_neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
draft_model.eval()
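# speculative_token_tree is a string-encoded dict[int, list[int]]; parse it
# into the tree structure expected by EagleSpeculativeDecoder.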
token_tree: dict[int, list[int]] = ast.literal_eval(
speculation_config.speculative_token_tree)
speculation_model = EagleSpeculativeDecoder(draft_model.model,
target_model.model,
token_tree=token_tree)
speculation_model.to_neuron()
return NeuronSpeculationCausalLM(speculation_model)

View File

@ -1,685 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in
the neuronx-distributed-inference framework."""
# Disabling yapf because yapf and isort have conflicts for the below imports
# yapf: disable
import copy
import hashlib
import importlib
import multiprocessing
import os
import shutil
from typing import Optional
import torch
import torch.nn as nn
from neuronx_distributed_inference.models.config import (
FusedSpecNeuronConfig, OnDeviceSamplingConfig)
from neuronx_distributed_inference.models.mllama.utils import (
create_vision_mask)
from neuronx_distributed_inference.modules.lora_serving import (
LoraServingConfig)
from neuronx_distributed_inference.utils.hf_adapter import (
load_pretrained_config)
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
# yapf: enable
logger = init_logger(__name__)
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "float32",
"half": "float16",
"float16": "float16",
"bfloat16": "bfloat16",
"float": "float32",
"float32": "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.float32: "float32",
}
# Models supported by Neuronx distributed for inference.
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
"LlamaForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"MistralForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"DbrxForCausalLM":
("neuronx_distributed_inference.models.dbrx.modeling_dbrx",
"NeuronDbrxForCausalLM"),
"MixtralForCausalLM":
("neuronx_distributed_inference.models.mixtral.modeling_mixtral",
"NeuronMixtralForCausalLM"),
"MllamaForConditionalGeneration":
("neuronx_distributed_inference.models.mllama.modeling_mllama",
"NeuronMllamaForCausalLM"),
}
class NeuronCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
prev_hidden: Optional[torch.Tensor] = None,
adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
output = self.model(input_ids,
attention_mask=None,
position_ids=positions,
seq_ids=sorted_input_block_ids,
sampling_params=sampling_params,
prev_hidden=prev_hidden,
adapter_ids=adapter_ids)
# on-device sampling
if self.config.neuron_config.on_device_sampling_config:
output = output.hidden_states
else:
output = output.logits[:, -1, :]
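# Undo the earlier sort so outputs line up with the original request order.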
restored_indices = torch.argsort(sorted_indices)
if input_block_ids.shape[0] != 1:
output = torch.index_select(output, 0, restored_indices)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
# on-device sampling
if self.config.neuron_config.on_device_sampling_config:
batch_size = logits.shape[0]
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
assert len(seq_ids) == batch_size, "batch size mismatch"
# With on-device sampling, `logits` already holds the sampled token ids
# (one per sequence).
accepted_token_ids_by_step = logits.flatten()
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
step_output_token_ids = []
for i, seq_id in enumerate(seq_ids):
token_id = accepted_token_ids_by_step[i]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
return SamplerOutput(outputs=step_output_token_ids)
else:
return self.sampler(logits, sampling_metadata)
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
self.config.neuron_config = neuron_config
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
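# Hash the full config so each distinct Neuron configuration gets its own
# compiled-artifact directory.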
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
usedforsecurity=False).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
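# Try to load previously compiled artifacts first; fall back to
# recompiling below on failure.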
try:
self.model = neuronx_model_cls(compiled_model_path)
override_neuron_config = kwargs["override_neuron_config"]
for k, v in override_neuron_config.items():
setattr(self.model.config.neuron_config, k, v)
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError) as e:
logger.warning("Exception: %s", e)
logger.warning("Failed to load the model from %s, Recompiling...",
compiled_model_path)
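# Ensure a local copy of the HF checkpoint exists before recompiling.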
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
self.model = neuronx_model_cls(model_name_or_path, config)
self.model.compile(compiled_model_path)
self.model.load(compiled_model_path)
class NeuronMllamaForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
on_device_sampling_disabled: bool = False) -> None:
super().__init__()
# has_image is the only multimodal input used during token generation.
# This CPU-side cache stores has_image per sequence id and holds at
# most batch_size entries.
self.has_image_cache: dict[int, torch.Tensor] = {}
self.config = config
self.logits_processor = LogitsProcessor(
config.get_text_config().vocab_size, logits_as_input=True)
self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
self.is_reorder_needed: bool = True
def read_from_has_image_cache(self, seq_ids: torch.Tensor):
has_image_list = []
for index in range(len(seq_ids)):
seq_id = seq_ids[index].item()
if seq_id in self.has_image_cache:
has_image_list.append(self.has_image_cache[seq_id])
else:
has_image_list.append(torch.tensor([0]))
return torch.tensor(has_image_list)
def write_to_has_image_cache(self, seq_ids: torch.Tensor,
has_image: torch.Tensor):
for index in range(len(seq_ids)):
seq_id = seq_ids[index].item()
if index < len(has_image):
self.has_image_cache[seq_id] = has_image[index]
else:
self.has_image_cache[seq_id] = torch.zeros(1)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
seq_ids: torch.Tensor, pixel_values: torch.Tensor,
aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
has_image: torch.Tensor, sampling_params) -> torch.Tensor:
# We update the has_image cache during prefill
# and read the has_image cache during decode
if input_ids.shape[-1] > 1: # prefill
self.write_to_has_image_cache(seq_ids, has_image)
else:
has_image = self.read_from_has_image_cache(seq_ids)
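# Decode steps carry no new image data, so pass zero placeholders for
# the chunk metadata.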
bs = input_ids.shape[0]
num_chunks = torch.zeros((bs, 1))
aspect_ratios = torch.zeros((bs, 1, 2))
input_block_ids = seq_ids
origin_input_block_ids = seq_ids
if self.is_reorder_needed:
# sort block ids sequentially for perf/neuron support reasons
input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
aspect_ratios = torch.index_select(aspect_ratios, 0,
sorted_indices)
num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
has_image = torch.index_select(has_image, 0, sorted_indices)
self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
output = self.model(
input_ids.to(torch.int32),
attention_mask=None,
position_ids=positions.to(torch.int32),
seq_ids=seq_ids.flatten().to(torch.int32),
pixel_values=pixel_values.to(
self.config.vision_config.torch_dtype),
aspect_ratios=aspect_ratios.to(torch.int32),
vision_mask=self.vision_mask.to(torch.int32),
sampling_params=sampling_params,
num_chunks=num_chunks.to(torch.int32),
has_image=has_image.to(torch.int32),
)
if self.config.neuron_config.on_device_sampling_config:
output = output.hidden_states
else:
output = output.logits[:, -1, :]
if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
restored_indices = torch.argsort(sorted_indices)
output = torch.index_select(output, 0, restored_indices)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(self, hidden_states, sampling_metadata):
if not self.on_device_sampling_disabled:
with torch.profiler.record_function("sample"):
hidden_states = hidden_states.flatten()
res = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
samples = []
for seq_id in seq_ids:
token_id = hidden_states[sample_idx].item()
samples.append(
SequenceOutput(
parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
res.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))
next_tokens = SamplerOutput(outputs=res)
else:
next_tokens = self.sampler(None, hidden_states, sampling_metadata)
return next_tokens
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
self.config.neuron_config = neuron_config
logger.info("neuron_config buckets: %s",
self.config.neuron_config.buckets)
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
usedforsecurity=False).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
try:
self.model = neuronx_model_cls(compiled_model_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.vision_token_id = tokenizer(
"<|image|>", add_special_tokens=False).input_ids[0]
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError):
logger.warning("Failed to load the model from %s, Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
self.model = neuronx_model_cls(model_name_or_path, config)
logger.info("\nCompiling and saving model to %s", model_name_or_path)
p = multiprocessing.Process(target=compile_model,
args=(self, compiled_model_path))
p.start()
p.join()
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.save_pretrained(compiled_model_path)
logger.info("Successfully compiled and saved the model in %s",
compiled_model_path)
# Read "<|image|>" token_id from the tokenizer
self.vision_token_id = tokenizer("<|image|>",
add_special_tokens=False).input_ids[0]
logger.info("\nLoading model from compiled checkpoint...")
self.model.load(compiled_model_path)
def compile_model(neuron_model, traced_model_path):
neuron_model.model.compile(traced_model_path)
class NeuronSpeculationCausalLM(nn.Module):
"""A Neuron-optimized causal language model with speculative decoding."""
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
) -> torch.Tensor:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
output = self.model(input_ids,
attention_mask=None,
position_ids=positions,
seq_ids=sorted_input_block_ids,
sampling_params=sampling_params)
restored_indices = torch.argsort(sorted_indices)
# Context (prefill) encoding: all sequences start at position 0, so only
# the first fused output token per sequence is returned.
if (positions[:, 0]).sum().item() == 0:
output = output.fused_outputs[0][:, 0:1]
if input_block_ids.shape[0] != 1:
output = torch.index_select(output, 0, restored_indices)
return output
# Fused Spec (Generation)
accepted_tokens_with_padding = output.fused_outputs[0]
next_pos_ids = output.fused_outputs[-1]
generated_token_counts = next_pos_ids - positions
assert not torch.any(generated_token_counts == 0).item(), \
"NxDI model generated no output for one or more sequences."
batch_size, steps = accepted_tokens_with_padding.shape
mask = torch.arange(steps).expand(batch_size,
-1) >= generated_token_counts
accepted_tokens_with_padding[mask] = -1
if input_block_ids.shape[0] != 1:
accepted_tokens_with_padding = torch.index_select(
accepted_tokens_with_padding, 0, restored_indices)
return accepted_tokens_with_padding
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[list[SamplerOutput]]:
batch_size, num_steps = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
sampler_output_list = []
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for sequence_index in range(batch_size):
token_id = accepted_token_ids_by_step[step_index][
sequence_index]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
return sampler_output_list
def load_weights(self, model_name_or_path: str,
draft_model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
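# Derive the draft model's Neuron config from the target config, then
# adjust the speculation-specific fields below.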
draft_neuron_config = copy.deepcopy(config.neuron_config)
if not config.neuron_config.enable_eagle_speculation:
draft_neuron_config.speculation_length = 0
draft_neuron_config.trace_tokengen_model = True
draft_neuron_config.enable_fused_speculation = False
if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
None):
draft_neuron_config.modules_to_not_convert = (
draft_neuron_config.draft_model_modules_to_not_convert)
if config.neuron_config.enable_eagle_speculation:
draft_neuron_config.is_eagle_draft = True
draft_neuron_config.sequence_parallel_enabled = False
draft_config = neuronx_model_cls.get_config_cls()(
draft_neuron_config,
load_config=load_pretrained_config(draft_model_name_or_path))
fused_spec_config = (FusedSpecNeuronConfig(
neuronx_model_cls._model_cls,
draft_config=draft_config,
draft_model_path=draft_model_name_or_path))
config.fused_spec_config = fused_spec_config
self.config.neuron_config = neuron_config
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
usedforsecurity=False).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
try:
self.model = neuronx_model_cls(compiled_model_path)
override_neuron_config = kwargs["override_neuron_config"]
for k, v in override_neuron_config.items():
setattr(self.model.config.neuron_config, k, v)
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError) as e:
logger.warning("Exception: %s", e)
logger.warning("Failed to load the model from %s Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
if not os.path.exists(draft_model_name_or_path):
if draft_model_name_or_path != model_name_or_path:
hf_model = AutoModelForCausalLM.from_pretrained(
draft_model_name_or_path)
saved_path = os.path.join("local-models",
draft_model_name_or_path)
hf_model.save_pretrained(saved_path)
draft_model_name_or_path = saved_path
else:
draft_model_name_or_path = model_name_or_path
config.fused_spec_config.draft_model_path = draft_model_name_or_path
self.model = neuronx_model_cls(model_name_or_path, config)
self.model.compile(compiled_model_path)
self.model.load(compiled_model_path)
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
return arch
raise ValueError(
f"Model architectures {architectures} are not supported on Neuron "
f"for now. Supported architectures: "
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_default_neuron_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_serving_config: LoraServingConfig):
"""Generate a neuron config based on vllm config args."""
on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True,
deterministic=False)
batch_size = scheduler_config.max_num_seqs
neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size,
ctx_batch_size=1,
batch_size=batch_size,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
enable_bucketing=True,
is_continuous_batching=True,
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
padding_side="right",
on_device_sampling_config=on_device_sampling_config,
sequence_parallel_enabled=True,
lora_serving_config=lora_serving_config)
return neuron_config
def _get_default_speculation_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Generate a neuron config for speculative decoding based on vllm config
args."""
neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size,
ctx_batch_size=1,
batch_size=scheduler_config.max_num_seqs,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
speculation_length=speculation_config.num_speculative_tokens,
trace_tokengen_model=False,
enable_fused_speculation=True,
enable_bucketing=True,
is_continuous_batching=True,
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
on_device_sampling_config=dict(
top_k=1,
do_sample=False,
))
return neuron_config
def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
"""Update default neuron config values with override args"""
overridden_neuron_config = overridden_neuron_config or {}
default_neuron_config.update(overridden_neuron_config)
return default_neuron_config
def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_serving_config: LoraServingConfig) -> nn.Module:
"""Initializes a neuron-optimized model for inference."""
model_arch = _get_model_architecture(model_config.hf_config)
if model_arch == "MllamaForConditionalGeneration":
model = NeuronMllamaForCausalLM(model_config.hf_config)
else:
model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config, lora_serving_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
override_neuron_config = model_config.override_neuron_config
model.load_weights(model_config.model,
neuron_config=neuron_config,
override_neuron_config=override_neuron_config)
return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized speculation model for inference.
This model handles speculation with either a standalone draft model or an EAGLE draft model.
"""
model = NeuronSpeculationCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_speculation_config(
model_config, parallel_config, scheduler_config, speculation_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
override_neuron_config = model_config.override_neuron_config
model.load_weights(model_config.model,
speculation_config.draft_model_config.model,
neuron_config=neuron_config,
override_neuron_config=override_neuron_config)
return model.eval()

View File

@ -169,37 +169,12 @@ def cpu_platform_plugin() -> Optional[str]:
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
def neuron_platform_plugin() -> Optional[str]:
tnx_installed = False
nxd_installed = False
logger.debug("Checking if Neuron platform is available.")
try:
import transformers_neuronx # noqa: F401
tnx_installed = True
logger.debug("Confirmed Neuron platform is available because"
" transformers_neuronx is found.")
except ImportError:
pass
try:
import neuronx_distributed_inference # noqa: F401
nxd_installed = True
logger.debug("Confirmed Neuron platform is available because"
" neuronx_distributed_inference is found.")
except ImportError:
pass
is_neuron = tnx_installed or nxd_installed
return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None
builtin_platform_plugins = {
'tpu': tpu_platform_plugin,
'cuda': cuda_platform_plugin,
'rocm': rocm_platform_plugin,
'xpu': xpu_platform_plugin,
'cpu': cpu_platform_plugin,
'neuron': neuron_platform_plugin,
}

View File

@ -73,7 +73,6 @@ class PlatformEnum(enum.Enum):
TPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
OOT = enum.auto()
UNSPECIFIED = enum.auto()
@ -164,9 +163,6 @@ class Platform:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON
def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT

View File

@ -1,151 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import os
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
logger = init_logger(__name__)
class NeuronFramework(enum.Enum):
TRANSFORMERS_NEURONX = "transformers-neuronx"
NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference"
class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
device_name: str = "neuron"
device_type: str = "neuron"
ray_device_key: str = "neuron_cores"
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
dist_backend: str = "gloo"
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "neuron"
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"
if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "uni"
if vllm_config.cache_config and vllm_config.model_config:
# neuron needs block_size = max_model_len
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len # type: ignore
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.")
return False
@classmethod
def get_device_communicator_cls(cls) -> str:
if envs.VLLM_USE_V1:
return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa
else:
return Platform.get_device_communicator_cls()
@classmethod
def use_all_gather(cls) -> bool:
return True
@classmethod
@lru_cache
def is_neuronx_distributed_inference(cls) -> bool:
try:
import neuronx_distributed_inference
except ImportError:
neuronx_distributed_inference = None
return neuronx_distributed_inference is not None
@classmethod
@lru_cache
def is_transformers_neuronx(cls) -> bool:
try:
import transformers_neuronx
except ImportError:
transformers_neuronx = None
return transformers_neuronx is not None
def get_neuron_framework_to_use(self):
"""Return the specified framework if corresponding installations are
available.
If no framework is specified, use neuronx-distributed-inference by
default.
If that's unavailable, check and switch to transformers-neuronx.
"""
if not self.is_neuron():
raise AssertionError(
f"Neuron Framework unavailable for platform: {self}")
tnx_installed = self.is_transformers_neuronx()
nxd_installed = self.is_neuronx_distributed_inference()
specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK")
tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value
nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value
if specified_framework == tnx_framework and tnx_installed:
return NeuronFramework.TRANSFORMERS_NEURONX
if ((specified_framework == nxd_framework and nxd_installed)
or (specified_framework is None and nxd_installed)):
return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
if specified_framework is None and tnx_installed:
return NeuronFramework.TRANSFORMERS_NEURONX
return None
def use_neuronx_distributed(self):
"""
Return True if the framework determined in get_neuron_framework_to_use()
is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This
is used to select the Neuron model framework and framework-specific
configuration to apply during model compilation.
"""
nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
return self.get_neuron_framework_to_use() == nxd_framework
def use_transformers_neuronx(self):
"""
Return True if the framework determined in get_neuron_framework_to_use()
is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used
to select the Neuron model framework and framework-specific
configuration to apply during model compilation.
"""
return self.get_neuron_framework_to_use(
) == NeuronFramework.TRANSFORMERS_NEURONX

View File

@ -1,455 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from vllm.config import DeviceConfig, VllmConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: SamplingMetadata = None
multi_modal_kwargs: BatchedTensorInputs = None
adapter_ids: Optional[str] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
return {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"input_block_ids": self.input_block_ids,
"sampling_metadata": self.sampling_metadata,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForNeuron":
return ModelInputForNeuron(
input_tokens=tensor_dict["input_tokens"],
input_positions=tensor_dict["input_positions"],
input_block_ids=tensor_dict["input_block_ids"],
sampling_metadata=tensor_dict["sampling_metadata"],
multi_modal_kwargs=tensor_dict["multi_modal_kwargs"],
)
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
"""A model runner for AWS Neuron hardware"""
# NEURON has an upper limit on the top_k
_MAX_NEURON_SAMPLING_TOP_K = 256
def __init__(
self,
vllm_config: VllmConfig,
):
ModelRunnerBase.__init__(self, vllm_config)
if (self.model_config is not None
and self.model_config.get_sliding_window()):
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (self.device_config if self.device_config
is not None else DeviceConfig())
self.lora_config = vllm_config.lora_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
# Lazy initialization.
self.model: nn.Module # initialize after load_model.
# Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
# turn off on-device sampling.
self._on_device_sampling_disabled = int(
os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
# NEURON needs to update sampling parameters when request IDs change
# across batches. This variable stores the previous batch's request IDs
# to determine if an update is needed.
self._previous_batch_request_ids: List[str] = []
if not self._on_device_sampling_disabled:
self._init_neuron_sampling()
def _init_neuron_sampling(self) -> None:
if current_platform.use_transformers_neuronx():
from transformers_neuronx.config import GenerationConfig
else:
from transformers import GenerationConfig
logger.warning(
"On-device sampling is turned on in Neuron by default, only "
"top_k, top_p, and temperature are current supported sampling "
"parameters. To turn off the on-device sampling, please set "
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
self.model_config.neuron_sampling_params = GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
def get_model(self) -> nn.Module:
return self.model
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
seq_lens: List[int] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
seq_len = len(prompt_tokens)
seq_lens.append(seq_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(seq_len)))
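# block_size equals max_model_len on Neuron, so each sequence maps to
# exactly one block.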
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
mm_kwargs = seq_group_metadata.multi_modal_data
if mm_kwargs:
mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs)
multi_modal_kwargs_list.append(mm_kwargs)
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
context_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
input_tokens = make_tensor_with_pad(input_tokens,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
seq_lens = None
if not self._on_device_sampling_disabled:
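# Rewrite each request's sampling params with Neuron-supported values
# (greedy for temperature 0.0, clamped top_k).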
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
top_k, top_p, temperature = (
self._convert_to_neuron_sampling_params(sampling_params))
sampling_params.top_k = top_k
sampling_params.top_p = top_p
sampling_params.temperature = temperature
# we need multi_modal_data for later tokens as well
multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list:
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
multi_modal_kwargs_list.append(mm_data)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since the Neuron worker doesn't support chunked prefill,
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
if current_platform.use_transformers_neuronx(
) and not self._on_device_sampling_disabled:
# Once the request IDs are changed in current iteration, we will
# update the on-device sampling parameters.
current_batch_request_ids = [
seq_group_meta_data.request_id
for seq_group_meta_data in seq_group_metadata_list
]
if current_batch_request_ids != self._previous_batch_request_ids:
self._update_neuron_sampling_params(seq_group_metadata_list)
self._previous_batch_request_ids = current_batch_request_ids
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs)
def _update_neuron_sampling_params(
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
# Update Neuron sampling parameters (GenerationConfig in Neuron)
current_sampling_params = self.model_config.neuron_sampling_params
assert current_sampling_params is not None, (
f"Failed to update sampling_params, "
f"current sampling params is {current_sampling_params}")
is_update_needed = False
top_k = current_sampling_params.top_k
top_p = current_sampling_params.top_p
temperature = current_sampling_params.temperature
# The index of a sequence's sampling parameters in neuron is equal to
# its index in `input_block_ids`.
for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_group_top_k = sampling_params.top_k
seq_group_top_p = sampling_params.top_p
seq_group_temperature = sampling_params.temperature
for seq_id in seq_ids:
index = seq_group_metadata.block_tables[seq_id][0]
if (top_k[index] != seq_group_top_k
or top_p[index] != seq_group_top_p
or temperature[index] != seq_group_temperature):
is_update_needed = True
top_k[index] = seq_group_top_k
top_p[index] = seq_group_top_p
temperature[index] = seq_group_temperature
# update_generation_config is only available in transformers-neuronx
if is_update_needed and current_platform.use_transformers_neuronx():
self.model.model.update_generation_config(current_sampling_params)
def _convert_to_neuron_sampling_params(
self, sampling_params: SamplingParams) -> Tuple[int, float, float]:
# Returns the top_k, top_p and temperature parameters for neuron.
top_k = sampling_params.top_k
top_p = sampling_params.top_p
temperature = sampling_params.temperature
if temperature == 0.0:
# Enable greedy sampling on zero temperature
return (1, 1.0, 1.0)
if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
top_k = self._MAX_NEURON_SAMPLING_TOP_K
return (top_k, top_p, temperature)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")
# extract top_k, top_p and temperature from model_input for neuron
# forward call
sampling_params = (torch.tensor([[
seq_group.sampling_params.top_k, seq_group.sampling_params.top_p,
seq_group.sampling_params.temperature
] for seq_group in model_input.sampling_metadata.seq_groups]))
if current_platform.use_neuronx_distributed():
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
adapter_ids=model_input.adapter_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
elif current_platform.use_transformers_neuronx():
# [TODO] validate on-device sampling
# The model signature may need change for on-device sampling
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
# Compute the logits only if the on-device sampling is turned off as
# on-device sampling outputs the token ids.
if self._on_device_sampling_disabled:
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
else:
logits = hidden_states
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
def process_multi_modal_data_neuron(self, mm_data):
# this is a no-op for NeuronModelRunner
return mm_data
def remove_all_loras(self):
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def add_lora(self, lora_request: LoRARequest):
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def list_loras(self) -> Set[int]:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")

View File

@ -1,189 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A Neuron worker class."""
import os
from typing import List, Optional, Set, Tuple
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class NeuronWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""
model_runner: NeuronModelRunner
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
self.lora_config = vllm_config.lora_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
neuron_framework = current_platform.get_neuron_framework_to_use()
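# Pick the model runner implementation for whichever Neuron framework
# is installed/selected.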
if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
self.model_runner = self.get_tnx_model_runner(vllm_config)
elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
self.model_runner = self.get_neuronx_distributed_model_runner(
vllm_config)
else:
raise NotImplementedError(
"Specified framework" +
f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" +
" is either not installed or not supported." +
" Supported frameworks: " +
"[transformers-neuronx, neuronx-distributed-inference]")
def get_tnx_model_runner(self, vllm_config):
assert (self.lora_config
is None), ("LoRA is not supported for TransformersNeuronX "
"framework.")
if self.speculative_config is not None:
raise NotImplementedError(
"Speculative decoding is not supported for TransformersNeuronX"
)
return NeuronModelRunner(vllm_config=vllm_config)
def get_neuronx_distributed_model_runner(self, vllm_config):
from vllm.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
if self.speculative_config is not None:
assert (self.lora_config is None), (
"LoRA is not supported for Speculative Decoding")
raise NotImplementedError(
"Speculative decoding is not supported for NeuronxDistributed")
return NeuronxDistributedModelRunner(vllm_config=vllm_config)
def init_device(self) -> None:
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
Swapping is not yet supported, so always return num_cpu_blocks=0.
We configure num_gpu_blocks to be max_num_seqs + 1.
"""
# Set the number of GPU blocks to the maximum number of sequences
# that can be processed in a single batch, plus one. This is
# equivalent to scheduling without PagedAttention.
num_gpu_blocks = self.scheduler_config.max_num_seqs + 1
# Swap not yet supported with Neuron backend.
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache.
"""
# Different values are not tested.
assert num_cpu_blocks == 0
assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
@property
def do_metadata_broadcast(self) -> bool:
return False
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
return WorkerInput(num_seq_groups=len(
execute_model_req.seq_group_metadata_list), )
def execute_worker(self, worker_input: WorkerInput) -> None:
pass
def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
This is required for speculative decoding; it is not yet implemented.
"""
raise NotImplementedError
def init_distributed_environment(self):
"""Neuron uses transformers-neuronx for tensor parallelism.
vLLM still needs the distributed environment initialized when TP/PP > 1.
"""
init_distributed_environment(
world_size=1,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
1,
1,
)
def add_lora(self, lora_request: LoRARequest) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.list_loras()

View File

@ -1,294 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional, Set
import torch
from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import (
get_all_supported_aspect_ratios)
from neuronx_distributed_inference.modules.generation.sampling import (
prepare_sampling_params)
from neuronx_distributed_inference.modules.lora_serving import (
LoraCheckpoint, LoraServingConfig)
from vllm.config import VllmConfig
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuronx_distributed import (
_get_model_architecture, get_neuron_model)
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
NeuronModelRunner)
logger = init_logger(__name__)
class NeuronxDistributedModelRunner(NeuronModelRunner):
def __init__(
self,
vllm_config: VllmConfig,
):
super().__init__(vllm_config)
self.lora_checkpoint = None
self.model = None
self.lora_serving_config = None
@staticmethod
def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]):
if not lora_modules:
return None
return {_.get("name"): _.get("path") for _ in lora_modules}
def _get_nxdi_lora_config(self):
override_neuron_config = self.model_config.override_neuron_config
lora_modules = override_neuron_config.pop("lora_modules", None)
target_modules = override_neuron_config.pop("target_modules", None)
lora_ckpt_paths = self._get_lora_paths_strings(lora_modules)
if self.lora_config.max_loras < len(lora_ckpt_paths):
raise ValueError(
f"Number of LoRAs ({len(lora_ckpt_paths)}) exceeds maximum "
f"allowed ({self.lora_config.max_loras})")
return LoraServingConfig(
max_loras=self.lora_config.max_loras,
max_lora_rank=self.lora_config.max_lora_rank,
target_modules=target_modules,
lora_ckpt_paths=lora_ckpt_paths,
)
def load_model(self) -> None:
# Update LoRA config
if self.lora_config is not None:
self.lora_serving_config = self._get_nxdi_lora_config()
self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config)
self.model = get_neuron_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
lora_serving_config=self.lora_serving_config)
def get_nxd_sampling_params(self, sampling_metadata):
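# When a request's top_k is disabled (<= 0), fall back to the largest
# supported value: the on-device sampler's global_topk if configured,
# otherwise the full vocabulary size.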
if self.model.config.neuron_config.on_device_sampling_config:
max_topk = (self.model.config.neuron_config.
on_device_sampling_config.global_topk)
else:
max_topk = self.model.config.vocab_size
top_k = [1] * self.scheduler_config.max_num_seqs
top_p = [1.0] * self.scheduler_config.max_num_seqs
temperature = [1.0] * self.scheduler_config.max_num_seqs
for index, sequenceGroupToSample in enumerate(
sampling_metadata.seq_groups):
top_k[index] = (sequenceGroupToSample.sampling_params.top_k
if sequenceGroupToSample.sampling_params.top_k > 0
else max_topk)
top_p[index] = sequenceGroupToSample.sampling_params.top_p
temperature[index] = (
sequenceGroupToSample.sampling_params.temperature)
sampling_params = prepare_sampling_params(
batch_size=self.scheduler_config.max_num_seqs,
top_k=top_k,
top_p=top_p,
temperature=temperature)
return sampling_params
def get_multi_modal_data_neuron(self, input_images):
raise NotImplementedError("need to restore multi-modal support")
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "NeuronxDistributedModelRunner does not support multi-step "
                "execution.")
if _get_model_architecture(
self.model.config) != "MllamaForConditionalGeneration":
return super().execute_model(model_input, kv_caches,
intermediate_tensors, num_steps)
sampling_params = self.get_nxd_sampling_params(
model_input.sampling_metadata)
if model_input.multi_modal_kwargs.get('pixel_values') is not None:
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
seq_ids=model_input.input_block_ids,
pixel_values=model_input.multi_modal_kwargs.get(
'pixel_values'),
aspect_ratios=model_input.multi_modal_kwargs.get(
'aspect_ratios'),
sampling_params=sampling_params,
num_chunks=model_input.multi_modal_kwargs.get('num_chunks'),
has_image=model_input.multi_modal_kwargs.get(
'has_image').squeeze(1),
)
else:
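            # Text-only request: fall back to zero-filled dummy vision inputs
            # (fixed shapes, has_image=0) so the Mllama forward pass can still
            # run without real pixel data.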
bs = model_input.input_tokens.shape[0] if (model_input.input_tokens
is not None) else 1
empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560],
dtype=torch.bfloat16)
empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64)
num_chunks = torch.zeros((bs, 1), dtype=torch.int32)
has_image = torch.zeros([bs], dtype=torch.int32)
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
seq_ids=model_input.input_block_ids,
pixel_values=empty_pixel_values,
aspect_ratios=empty_aspect_ratios,
sampling_params=sampling_params,
num_chunks=num_chunks,
has_image=has_image,
)
output = self.model.sample(
hidden_states=hidden_states,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
def process_multi_modal_data_neuron(self, mm_data):
# Neuron uses aspect_ratios instead of aspect_ratio_ids
all_supported_aspect_ratios = get_all_supported_aspect_ratios(
self.model.config.vision_config.max_num_tiles)
aspect_ratio_ids = mm_data.get("aspect_ratio_ids")
mm_data["aspect_ratios"] = torch.tensor(
all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0)
# Neuron's num_chunks is HF's num_tiles
mm_data["num_chunks"] = mm_data.get("num_tiles")
        # The input contains an image only if pixel_values are present and
        # not all zeros.
bs = mm_data["num_chunks"].shape[0]
pixel_values = mm_data.get("pixel_values")
if pixel_values is not None and not torch.all(pixel_values == 0):
mm_data["has_image"] = torch.ones(bs)
else:
mm_data["has_image"] = torch.zeros(bs)
return mm_data
def _get_lora_adapter_ids(self, seq_group_metadata_list):
# set LoRA adapter IDs for multi-lora serving
batch_size = len(seq_group_metadata_list)
if self.lora_checkpoint is not None:
# "0" indicates NxDI to use the base model for inference
adapter_ids = ["0"] * batch_size
for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.lora_request is not None:
adapter_ids[
idx] = seq_group_metadata.lora_request.lora_name
# convert adapter_ids from strings to integers
adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices(
adapter_ids, batch_size)
else:
            adapter_ids = torch.zeros(batch_size, dtype=torch.int32)
return adapter_ids
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
seq_lens = None
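        # With on-device sampling enabled, rewrite each group's sampling params
        # into the (top_k, top_p, temperature) form the Neuron model expects.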
if not self._on_device_sampling_disabled:
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
top_k, top_p, temperature = (
self._convert_to_neuron_sampling_params(sampling_params))
sampling_params.top_k = top_k
sampling_params.top_p = top_p
sampling_params.temperature = temperature
# we need multi_modal_data for later tokens as well
multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list:
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
multi_modal_kwargs_list.append(mm_data)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs,
adapter_ids=lora_adapter_ids)
def remove_all_loras(self):
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def add_lora(self, lora_request: LoRARequest):
        logger.warning(
            "Adding LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config. If you "
            "supplied the parameter, you can ignore this warning. Ignoring "
            "lora request: %s", lora_request)
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def list_loras(self) -> Set[int]:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")