[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend (#14238)

Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Akshat Tripathi
2025-05-07 21:28:47 +01:00
committed by GitHub
parent db593aa67f
commit c20ef40fd0
19 changed files with 929 additions and 46 deletions

View File

@ -50,6 +50,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
&& echo TEST_12 \
&& pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
# Disable the TPU LoRA tests until the feature is activated
# && echo TEST_13 \
# && pytest -s -v /workspace/vllm/tests/tpu/lora/" \
# TODO: This test fails because it uses RANDOM_SEED sampling

View File

@ -47,7 +47,7 @@ def dist_init():
temp_file = tempfile.mkstemp()[1]
backend = "nccl"
if current_platform.is_cpu():
if current_platform.is_cpu() or current_platform.is_tpu():
backend = "gloo"
init_distributed_environment(world_size=1,

View File

124
tests/tpu/lora/test_lora.py Normal file
View File

@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import vllm
from vllm.lora.request import LoRARequest
# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.
# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.
@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
"""
Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
for all tests in this file
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
yield
def setup_vllm(num_loras: int) -> vllm.LLM:
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
num_scheduler_steps=1,
max_model_len=256,
max_seq_len_to_capture=256,
max_num_seqs=8,
enable_lora=True,
max_loras=num_loras,
max_lora_rank=8)
def test_single_lora():
"""
This test ensures we can run a single LoRA adapter on the TPU backend.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=1.
"""
llm = setup_vllm(1)
prompt = "What is 1+1? \n"
lora_request = LoRARequest(
"lora_adapter_1", 1,
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(max_tokens=256,
temperature=0),
lora_request=lora_request)[0].outputs[0].text
answer = output.strip()[0]
assert answer.isdigit()
assert int(answer) == 1
def test_lora_hotswapping():
"""
This test ensures we can run multiple LoRA adapters on the TPU backend, even
if we only have space to store 1.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
"""
lora_name_template = \
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
lora_requests = [
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
for i in range(1, 5)
]
llm = setup_vllm(1)
prompt = "What is 1+1? \n"
for i, req in enumerate(lora_requests):
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(
max_tokens=256, temperature=0),
lora_request=req)[0].outputs[0].text
answer = output.strip()[0]
assert answer.isdigit()
assert int(answer) == i + 1
def test_multi_lora():
"""
This test ensures we can run multiple LoRA adapters on the TPU backend, when
we have enough space to store all of them.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
"""
lora_name_template = \
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
lora_requests = [
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
for i in range(1, 5)
]
llm = setup_vllm(4)
prompt = "What is 1+1? \n"
for i, req in enumerate(lora_requests):
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(
max_tokens=256, temperature=0),
lora_request=req)[0].outputs[0].text
answer = output.strip()[0]
assert answer.isdigit()
assert int(output.strip()[0]) == i + 1

View File

@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]
DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]
def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
"""
Inputs: (All integers)
T: Total number of tokens
D: Input dim
L: LoRA Dim
N: N LoRAs
Outputs:
inputs: torch.Tensor - shape (T, D)
loras: torch.Tensor - shape (N, 1, L, D)
idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)
ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
"""
torch.manual_seed(seed)
inputs = torch.randn((T, D), device="xla", dtype=dtype)
loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")
ref_output = ref_bgmv(inputs, loras, idxs)
return inputs, loras, idxs, ref_output
def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
selected_loras = loras[idxs]
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(axis=1)
batch_size, output_size, input_size = selected_loras.shape
return (selected_loras @ inputs.reshape(
(batch_size, input_size, 1))).reshape((batch_size, output_size))
# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
if op_type == "expand":
D, L = L, D
inputs, loras, idxs, ref_output = generate_test_data(
T, D, L, N, seed, dtype)
# Run bgmv
output = torch.ops.xla.bgmv(inputs, loras, idxs)
# Make sure we have no NaNs
assert not torch.any(torch.isnan(output))
# Compare with reference output
assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)

View File

@ -2694,8 +2694,8 @@ class LoRAConfig:
lora_extra_vocab_size: int = 256
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
(added to the base model vocabulary)."""
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
lora_vocab_padding_size: ClassVar[int] = current_platform\
.get_lora_vocab_padding_size()
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
@ -2723,6 +2723,7 @@ class LoRAConfig:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size)
factors.append(self.long_lora_scaling_factors)
factors.append(self.bias_enabled)
hash_str = hashlib.md5(str(factors).encode(),

View File

@ -16,6 +16,7 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.platforms import current_platform
if TYPE_CHECKING:
pass
@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
device=x.device,
)
layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink(
buffers, x, layer.lora_a_stacked, 1.0)
if not current_platform.can_update_inplace():
buffers = shrunk_buffers
buffers = tensor_model_parallel_all_gather(buffers)
layer.punica_wrapper.add_expand(output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)
lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand(
output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
@ -292,7 +303,11 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
device=x.device,
)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink(
buffer, x, self.lora_a_stacked, 1.0)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
@ -304,7 +319,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# NOTE offset are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
self.punica_wrapper.add_expand(
lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
@ -313,6 +328,10 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
offset_start=offset_start,
add_input=True,
)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
return output

View File

@ -261,10 +261,17 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
full_lora_a_embeddings.shape[1],
-1,
)
self.punica_wrapper.add_lora_embedding(full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_embedding(
full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)
if not current_platform.can_update_inplace():
full_output = lora_output
return full_output.view_as(full_output_org)
@classmethod
@ -410,10 +417,13 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
output = output.flatten(0, 1)
x = x.flatten(0, 1)
self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
self.lora_b_stacked,
self.lora_bias_stacked, 1.0,
self.output_slices)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked,
self.lora_bias_stacked, 1.0, self.output_slices)
if not current_platform.can_update_inplace():
output = lora_output
return output
@property
@ -1133,15 +1143,23 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
lora_logits[-1] = float("-inf")
neg_inf, pos_inf = current_platform.get_infinity_values(
lora_logits.dtype)
lora_logits[-1] = neg_inf
lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded
if current_platform.is_tpu():
indices_padded = indices_padded[:logits.size(0)]
lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
posinf=pos_inf,
neginf=neg_inf))
# HPU needs special handling to prune out dummy samples.
if current_platform.is_hpu():
@ -1151,10 +1169,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
# LogitsProcessorWithLoRA always using bgmv
self.punica_wrapper.add_lora_logits(logits, hidden_states,
self.lora_a_stacked,
self.lora_b_stacked, 1.0)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_logits(
logits, hidden_states, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
if not current_platform.can_update_inplace():
logits = lora_output
# Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size]

View File

@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink)
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]

View File

@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
import torch
# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import
def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
outputs = torch.cat(
(outputs,
torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
device=outputs.device)),
dim=1)
if add_inputs:
return output_tensor + outputs[:limit, :]
else:
return outputs[:limit, :]
def bgmv_shrink(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
scaling (float, optional): Scalar multiplier applied to the output.
"""
return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights,
lora_indices_tensor)
def bgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
n_tokens = outputs.size(0)
outputs = torch.cat((
torch.zeros((n_tokens, slice_offset), device=outputs.device),
outputs,
torch.zeros(
(n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
device=outputs.device),
),
dim=1)
if add_inputs:
return output_tensor + outputs
else:
return outputs

View File

@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
import functools
import jax
import jax.numpy as jnp
import torch
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from torch.library import impl
from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
make_kernel_from_pallas)
# TODO: Tune these
TOKENS_BLOCK = 16
LORA_RANK_BLOCK = 128
DIM_BLOCK_SIZE = 128
def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref,
acc_ref, mask_ref):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)
t = pl.program_id(0)
for i in range(bT):
idx = idx_ref[i + bT * t]
mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32)
acc_ref[...] += jax.lax.dot_general(
inp_ref[...],
lora_ref[idx, ...], (((1, ), (1, )), ((), ())),
preferred_element_type=jnp.float32) * mask_ref[...]
@pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
def _():
out_ref[...] = acc_ref[...].astype(out_ref.dtype)
@jax.jit
def _bgmv(
idxs: jax.Array, # (T, ) int32
inputs: jax.Array, # (T, D) model dtype
loras: jax.Array # (N, L, D) model dtype
) -> jax.Array: # (T, L) model dtype
T, D = inputs.shape
N, L, _ = loras.shape
return pl.pallas_call(
kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK),
out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK,
D // DIM_BLOCK_SIZE),
in_specs=[
pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE),
lambda i, j, k, block_idx: (i, k)),
pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE),
lambda i, j, k, block_idx: (0, j, k)),
],
out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK),
lambda i, j, k, block_idx: (i, j)),
scratch_shapes=[
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32),
pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32)
]),
compiler_params=pltpu.TPUCompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
name="bgmv")(idxs, inputs, loras)
def bgmv_shape_function(idxs, inputs, loras):
T, _ = inputs.shape
_, L, _ = loras.shape
return [((T, L), inputs.dtype)]
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
inputs = inputs.to(dtype=loras.dtype)
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function)
T, _ = inputs.shape
_, L, D = loras.shape
# Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
# register. This has to happen in pytorch, doing it in Jax will lead to NaNs
L1 = L
if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0:
L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK
D1 = D
if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0:
D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE
T1 = T
if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0:
T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK
if D1 != D or L1 != L:
loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0))
if D1 != D or T1 != T:
inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T))
if T1 != T:
idxs = torch.nn.functional.pad(idxs, ((0, T1 - T)))
return kernel(idxs, inputs, loras)[:T, :L]
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)

View File

@ -48,7 +48,7 @@ class PunicaWrapperABC(ABC):
lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> None:
) -> Optional[torch.Tensor]:
"""
Performs GEMM for multiple slices of lora_a.
"""
@ -66,7 +66,7 @@ class PunicaWrapperABC(ABC):
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
) -> Optional[torch.Tensor]:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
"""
@ -80,7 +80,7 @@ class PunicaWrapperABC(ABC):
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
) -> Optional[torch.Tensor]:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
and this layer only requires the expand operation.
@ -98,7 +98,7 @@ class PunicaWrapperABC(ABC):
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
**kwargs) -> Optional[torch.Tensor]:
"""
Applicable to linear-related lora.
"""
@ -114,7 +114,7 @@ class PunicaWrapperABC(ABC):
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
**kwargs) -> Optional[torch.Tensor]:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
"""
@ -207,7 +207,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self._long_lora_indices.zero_()
self.indices_len[:] = indices_len
def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
def _update_prefill_metadata(self,
token_lora_tensor: torch.Tensor) -> None:
(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, token_nums,
@ -334,7 +335,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
long_lora_context)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metada(self.token_lora_indices)
self._update_prefill_metadata(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
@ -342,7 +343,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
@abstractmethod
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs) -> None:
scale: float, **kwargs) -> Optional[torch.Tensor]:
"""
Performs GEMM for multiple slices of lora_a.
@ -369,7 +370,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> None:
**kwargs) -> Optional[torch.Tensor]:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
@ -401,7 +402,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs) -> None:
**kwargs) -> Optional[torch.Tensor]:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
and this layer only requires the expand operation.
@ -428,7 +429,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
**kwargs) -> Optional[torch.Tensor]:
"""
Applicable to linear-related lora.
@ -463,7 +464,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
**kwargs) -> Optional[torch.Tensor]:
"""
Applies lora specifically for LogitsProcessorWithLoRA.

View File

@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from .punica_base import PunicaWrapperBase
class PunicaWrapperTPU(PunicaWrapperBase):
"""
PunicaWrapperTPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
# PunicaWrapperBase defines some tensors with dtype=torch.int64, which
# isn't supported by the TPU. So convert those tensors to int32.
# Not all of them are used by the TPU so only convert the useful ones.
self._token_lora_indices = self._token_lora_indices.to(
dtype=torch.int32)
self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
self._sampler_indices_padded = self._sampler_indices_padded.to(
dtype=torch.int32)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA.
"""
return self._embeddings_indices[:]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
return self._sampler_indices_padded[:]
def shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
if self.no_lora:
return y
return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x),
scale)
def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
add_inputs: bool):
return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x),
add_inputs)
def expand_slice(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, y_offset: int, y_slice_size: int,
y_total_size: int, add_inputs: bool) -> torch.Tensor:
return bgmv_expand_slice(x, w_t_all, y,
self._get_token_lora_indices(x), y_offset,
y_slice_size, add_inputs)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs) -> Optional[torch.Tensor]:
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
torch.ops.xla.dynamo_set_buffer_donor_(y, True)
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
y_s = y[slice_idx]
lora_s = lora_a_stacked[slice_idx]
y_s = self.shrink(y_s, x, lora_s, scale)
y[slice_idx, :, :] = y_s # type: ignore[index]
return y
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> torch.Tensor:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = 0
if lora_bias_stacked is not None:
y = self._apply_bias(self._get_token_lora_indices(y), y,
output_slices, lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
y = self.expand_slice(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
y_total_size=sum(output_slices),
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
return y.view_as(y_org)
def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs) -> torch.Tensor:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# Embedding layer only needs the expand op
return self.expand(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> torch.Tensor:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will not be changed in-place.
x (torch.Tensor): Input tensor (T, E)
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if lora_bias_stacked is not None:
assert len(lora_bias_stacked) == len(output_slices)
y = self._apply_bias(self._get_token_lora_indices(y), y,
output_slices, lora_bias_stacked)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, consistent with the
# triton op
T = x.size(0)
buffer = torch.zeros(
(len(output_slices), T, r),
dtype=torch.float32,
device=x.device,
)
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
return self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_inputs=True,
**kwargs)
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
if self.no_lora:
return y
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices,
scale)
y = bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
return y.view_as(y_org)
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias = torch.where(indices[:, None] == -1, 0, bias)
bias = F.pad(bias, (offset_left, output.shape[1] -
(offset_left + slice), 0, 0))
output += bias
offset_left += slice
return output.view_as(org_output)
def _update_prefill_metadata(self,
token_lora_tensor: torch.Tensor) -> None:
self.batch_size = 1
self._lora_indices_per_batch[:self.batch_size].copy_(
token_lora_tensor[:self.batch_size])
# TODO: .item() is extremely inefficient on TPU, so find a way around it
self.no_lora = torch.all(token_lora_tensor == -1).item()

View File

@ -125,11 +125,13 @@ def convert_mapping(
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1,
embeddings_indices)
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = torch.where(sampler_indices_padded == -1,
max_loras - 1, sampler_indices_padded)
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))

View File

@ -332,6 +332,27 @@ class Platform:
"""
raise NotImplementedError
@classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
"""
Return the platform specific values for (-inf, inf)
"""
return float("-inf"), float("inf")
@classmethod
def can_update_inplace(cls) -> bool:
"""
Checks if the platform allows inplace memory updates
"""
return True
@classmethod
def get_lora_vocab_padding_size(cls) -> int:
"""
Returns how much padding the LoRA logits need for kernels
"""
return 256
@classmethod
def get_device_communicator_cls(cls) -> str:
"""

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
from tpu_info import device
@ -67,6 +67,22 @@ class TpuPlatform(Platform):
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return not envs.VLLM_USE_V1
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
@classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
return torch.finfo(dtype).min, torch.finfo(dtype).max
@classmethod
def can_update_inplace(cls):
return False
@classmethod
def get_lora_vocab_padding_size(cls) -> int:
return 1
@classmethod
def inference_mode(cls):
return torch.no_grad()

View File

@ -39,6 +39,7 @@ from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import sanity_check_mm_encoder_outputs
@ -90,7 +91,7 @@ MIN_NUM_SEQS = 8
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch predictions are included as subgraph inputs to facilitate
# pre-compilation.
class TPUModelRunner:
class TPUModelRunner(LoRAModelRunnerMixin):
def __init__(
self,
@ -568,6 +569,17 @@ class TPUModelRunner:
self.device)
seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
padded_num_scheduled_tokens_per_req = np.copy(
num_scheduled_tokens_per_req
) # Copying to avoid accidental state corruption bugs
padded_num_scheduled_tokens_per_req[-1] += \
padded_total_num_scheduled_tokens - total_num_scheduled_tokens
self.set_active_loras(self.input_batch,
padded_num_scheduled_tokens_per_req)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
@ -907,6 +919,11 @@ class TPUModelRunner:
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config,
self.scheduler_config,
self.lora_config, self.device)
# Sync all pending XLA execution during model initialization and weight
# loading.
xm.mark_step()
@ -970,7 +987,10 @@ class TPUModelRunner:
for layer_name in layer_names
}
with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0):
with self.maybe_dummy_run_with_lora(
self.lora_config,
np.array([num_tokens], dtype=np.int32)), set_forward_context(
per_layer_attn_metadata, self.vllm_config, 0):
out = self.model(input_ids=input_ids,
positions=position_ids,
inputs_embeds=inputs_embeds)

View File

@ -15,6 +15,7 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
@ -82,6 +83,10 @@ class TPUWorker:
if self.model_config.seed is None:
self.model_config.seed = 0
if vllm_config.lora_config is not None:
raise NotImplementedError(
"The V1 TPU backend doesn't support LoRA serving")
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@ -211,6 +216,9 @@ class TPUWorker:
else:
xp.stop_trace()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def load_model(self) -> None:
self.model_runner.load_model()

View File

@ -54,6 +54,10 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
if self.model_config.seed is None:
self.model_config.seed = 0
if vllm_config.lora_config is not None:
raise NotImplementedError(
"The V0 TPU backend doesn't support LoRA serving")
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
torch.set_grad_enabled(False)