[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend (#14238)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
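For context, the user-facing flow this change targets is the standard vLLM Multi-LoRA API; a minimal sketch of it, mirroring the new tests added below (same model and adapter names as the tests, and assuming a TPU host running with VLLM_USE_V1=1):

    import vllm
    from vllm.lora.request import LoRARequest

    # Model and adapter are the ones used by tests/tpu/lora/test_lora.py below.
    llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                   max_model_len=256,
                   enable_lora=True,
                   max_loras=1,
                   max_lora_rank=8)

    req = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
    out = llm.generate("What is 1+1? \n",
                       sampling_params=vllm.SamplingParams(max_tokens=16,
                                                           temperature=0),
                       lora_request=req)[0].outputs[0].text
    print(out)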
@ -50,6 +50,9 @@ docker run --privileged --net host --shm-size=16G -it \
    && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
    && echo TEST_12 \
    && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
    # Disable the TPU LoRA tests until the feature is activated
    # && echo TEST_13 \
    # && pytest -s -v /workspace/vllm/tests/tpu/lora/" \

# TODO: This test fails because it uses RANDOM_SEED sampling
@ -47,7 +47,7 @@ def dist_init():
    temp_file = tempfile.mkstemp()[1]

    backend = "nccl"
    if current_platform.is_cpu():
    if current_platform.is_cpu() or current_platform.is_tpu():
        backend = "gloo"

    init_distributed_environment(world_size=1,
tests/tpu/lora/__init__.py (new file, 0 lines)

tests/tpu/lora/test_lora.py (new file, 124 lines)
@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """
    Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
    for all tests in this file
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        yield


def setup_vllm(num_loras: int) -> vllm.LLM:
    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                    num_scheduler_steps=1,
                    max_model_len=256,
                    max_seq_len_to_capture=256,
                    max_num_seqs=8,
                    enable_lora=True,
                    max_loras=num_loras,
                    max_lora_rank=8)


def test_single_lora():
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=1.
    """

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=256,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    answer = output.strip()[0]

    assert answer.isdigit()
    assert int(answer) == 1


def test_lora_hotswapping():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """

    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text
        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1


def test_multi_lora():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(4)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text

        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(output.strip()[0]) == i + 1
tests/tpu/lora/test_pallas_kernels.py (new file, 73 lines)
@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import

N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]

DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]


def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
    """
    Inputs: (All integers)
        T: Total number of tokens
        D: Input dim
        L: LoRA Dim
        N: N LoRAs

    Outputs:
        inputs: torch.Tensor - shape (T, D)
        loras: torch.Tensor - shape (N, 1, L, D)
        idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)

        ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
    """
    torch.manual_seed(seed)

    inputs = torch.randn((T, D), device="xla", dtype=dtype)
    loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
    idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")

    ref_output = ref_bgmv(inputs, loras, idxs)
    return inputs, loras, idxs, ref_output


def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
    selected_loras = loras[idxs]
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(axis=1)

    batch_size, output_size, input_size = selected_loras.shape
    return (selected_loras @ inputs.reshape(
        (batch_size, input_size, 1))).reshape((batch_size, output_size))


# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
    if op_type == "expand":
        D, L = L, D

    inputs, loras, idxs, ref_output = generate_test_data(
        T, D, L, N, seed, dtype)

    # Run bgmv
    output = torch.ops.xla.bgmv(inputs, loras, idxs)

    # Make sure we have no NaNs
    assert not torch.any(torch.isnan(output))

    # Compare with reference output
    assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
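For intuition, the reference above reduces to a per-token gather of one LoRA matrix followed by a matrix-vector product. A tiny CPU-only restatement of that semantics (toy shapes chosen here for illustration, plain PyTorch, no XLA involved):

    import torch

    T, D, L, N = 4, 8, 2, 3            # toy sizes: tokens, input dim, LoRA rank, num LoRAs
    inputs = torch.randn(T, D)
    loras = torch.randn(N, L, D)       # one (L, D) matrix per LoRA
    idxs = torch.randint(0, N, (T, ))

    # bgmv ("batched grouped matrix-vector"): each token is multiplied by the
    # LoRA matrix selected by its index.
    out = torch.stack([loras[idxs[t]] @ inputs[t] for t in range(T)])
    assert out.shape == (T, L)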
@ -2694,8 +2694,8 @@ class LoRAConfig:
    lora_extra_vocab_size: int = 256
    """Maximum size of extra vocabulary that can be present in a LoRA adapter
    (added to the base model vocabulary)."""
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    lora_vocab_padding_size: ClassVar[int] = current_platform\
        .get_lora_vocab_padding_size()
    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
    """Specify multiple scaling factors (which can be different from base model
    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
@ -2723,6 +2723,7 @@ class LoRAConfig:
        factors.append(self.fully_sharded_loras)
        factors.append(self.lora_dtype)
        factors.append(self.lora_extra_vocab_size)
        factors.append(self.lora_vocab_padding_size)
        factors.append(self.long_lora_scaling_factors)
        factors.append(self.bias_enabled)
        hash_str = hashlib.md5(str(factors).encode(),
@ -16,6 +16,7 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.platforms import current_platform

if TYPE_CHECKING:
    pass
@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
        device=x.device,
    )

    layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
    shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink(
        buffers, x, layer.lora_a_stacked, 1.0)

    if not current_platform.can_update_inplace():
        buffers = shrunk_buffers

    buffers = tensor_model_parallel_all_gather(buffers)
    layer.punica_wrapper.add_expand(output,
                                    buffers,
                                    layer.lora_b_stacked,
                                    layer.lora_bias_stacked,
                                    layer.output_slices,
                                    offset_start=0,
                                    add_input=True)

    lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand(
        output,
        buffers,
        layer.lora_b_stacked,
        layer.lora_bias_stacked,
        layer.output_slices,
        offset_start=0,
        add_input=True)

    if not current_platform.can_update_inplace():
        output = lora_output

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
@ -292,7 +303,11 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
            device=x.device,
        )

        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink(
            buffer, x, self.lora_a_stacked, 1.0)
        if not current_platform.can_update_inplace():
            buffer = shrunk_buffer

        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
@ -304,7 +319,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
        # NOTE offset are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        self.punica_wrapper.add_expand(
        lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
@ -313,6 +328,10 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
            offset_start=offset_start,
            add_input=True,
        )

        if not current_platform.can_update_inplace():
            output = lora_output

        output = output.view(*out_orig_shape)
        return output
@ -261,10 +261,17 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
                full_lora_a_embeddings.shape[1],
                -1,
            )
        self.punica_wrapper.add_lora_embedding(full_output,
                                               full_lora_a_embeddings,
                                               self.lora_b_stacked,
                                               add_input=True)

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_embedding(
                full_output,
                full_lora_a_embeddings,
                self.lora_b_stacked,
                add_input=True)

        if not current_platform.can_update_inplace():
            full_output = lora_output

        return full_output.view_as(full_output_org)

    @classmethod
@ -410,10 +417,13 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_linear(
                output, x, self.lora_a_stacked, self.lora_b_stacked,
                self.lora_bias_stacked, 1.0, self.output_slices)
        if not current_platform.can_update_inplace():
            output = lora_output

        return output

    @property
@ -1133,15 +1143,23 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")

        neg_inf, pos_inf = current_platform.get_infinity_values(
            lora_logits.dtype)

        lora_logits[-1] = neg_inf
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded

        if current_platform.is_tpu():
            indices_padded = indices_padded[:logits.size(0)]

        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
                                                      posinf=pos_inf,
                                                      neginf=neg_inf))

        # HPU needs special handling to prune out dummy samples.
        if current_platform.is_hpu():
@ -1151,10 +1169,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)
        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_logits(
                logits, hidden_states, self.lora_a_stacked,
                self.lora_b_stacked, 1.0)

        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
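The recurring pattern in the layer changes above: the punica wrapper may now return a fresh tensor rather than mutating its argument, and on platforms that cannot update tensors in place (the TPU/XLA wrapper added below) the caller rebinds to the returned value. A minimal sketch of the idea, with a hypothetical add_lora function standing in for the punica ops:

    import torch

    CAN_UPDATE_INPLACE = False  # what can_update_inplace() reports for the TPU/XLA path


    def add_lora(output: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for a punica op: a functional, XLA-friendly
        # style returns a new tensor instead of writing into `output`.
        return output + delta


    output = torch.zeros(4, 8)
    lora_output = add_lora(output, torch.ones(4, 8))

    if not CAN_UPDATE_INPLACE:
        # Rebind to the functional result, mirroring the edits above.
        output = lora_output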
vllm/lora/ops/xla_ops/__init__.py (new file, 6 lines)
@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
                                            bgmv_shrink)

__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
vllm/lora/ops/xla_ops/lora_ops.py (new file, 106 lines)
@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0

import torch

# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import


def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].

        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].

        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].

        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """

    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
    n_tokens = outputs.size(0)

    limit = output_tensor.shape[0]
    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
        limit = 1

    outputs = torch.cat(
        (outputs,
         torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]),
                     device=outputs.device)),
        dim=1)

    if add_inputs:
        return output_tensor + outputs[:limit, :]
    else:
        return outputs[:limit, :]


def bgmv_shrink(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        scaling (float, optional): Scalar multiplier applied to the output.
    """

    return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights,
                                        lora_indices_tensor)


def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].

        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].

        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].

        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """
    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
    n_tokens = outputs.size(0)

    outputs = torch.cat((
        torch.zeros((n_tokens, slice_offset), device=outputs.device),
        outputs,
        torch.zeros(
            (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)),
            device=outputs.device),
    ),
                        dim=1)

    if add_inputs:
        return output_tensor + outputs
    else:
        return outputs
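To make the slice layout concrete: bgmv_expand_slice zero-pads the kernel result on both sides so it lands in columns [slice_offset, slice_offset + slice_size) of the full output width. A small CPU-only sketch of that padding (toy shapes, all names local to this sketch):

    import torch

    n_tokens, out_width = 2, 10
    slice_offset, slice_size = 4, 3

    # Pretend this is the (n_tokens, slice_size) result of the bgmv kernel.
    kernel_out = torch.ones(n_tokens, slice_size)

    # Pad left and right with zeros so the slice lands at
    # [slice_offset, slice_offset + slice_size) of the full output width.
    padded = torch.cat((
        torch.zeros(n_tokens, slice_offset),
        kernel_out,
        torch.zeros(n_tokens, out_width - (slice_offset + slice_size)),
    ), dim=1)

    output = torch.zeros(n_tokens, out_width)
    output = output + padded  # equivalent of add_inputs=True
    assert output[:, slice_offset:slice_offset + slice_size].eq(1).all()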
vllm/lora/ops/xla_ops/pallas.py (new file, 133 lines)
@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
import functools

import jax
import jax.numpy as jnp
import torch
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from torch.library import impl
from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard,
                                                  make_kernel_from_pallas)

# TODO: Tune these
TOKENS_BLOCK = 16
LORA_RANK_BLOCK = 128
DIM_BLOCK_SIZE = 128


def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref,
                 acc_ref, mask_ref):

    @pl.when(pl.program_id(2) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32)

    t = pl.program_id(0)

    for i in range(bT):
        idx = idx_ref[i + bT * t]
        mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32)
        mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32)

        acc_ref[...] += jax.lax.dot_general(
            inp_ref[...],
            lora_ref[idx, ...], (((1, ), (1, )), ((), ())),
            preferred_element_type=jnp.float32) * mask_ref[...]

    @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
    def _():
        out_ref[...] = acc_ref[...].astype(out_ref.dtype)


@jax.jit
def _bgmv(
        idxs: jax.Array,  # (T, ) int32
        inputs: jax.Array,  # (T, D) model dtype
        loras: jax.Array  # (N, L, D) model dtype
) -> jax.Array:  # (T, L) model dtype
    T, D = inputs.shape
    N, L, _ = loras.shape

    return pl.pallas_call(
        kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK),
        out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK,
                  D // DIM_BLOCK_SIZE),
            in_specs=[
                pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE),
                             lambda i, j, k, block_idx: (i, k)),
                pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE),
                             lambda i, j, k, block_idx: (0, j, k)),
            ],
            out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK),
                                   lambda i, j, k, block_idx: (i, j)),
            scratch_shapes=[
                pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32),
                pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32)
            ]),
        compiler_params=pltpu.TPUCompilerParams(
            dimension_semantics=("parallel", "parallel", "arbitrary")),
        name="bgmv")(idxs, inputs, loras)


def bgmv_shape_function(idxs, inputs, loras):
    T, _ = inputs.shape
    _, L, _ = loras.shape

    return [((T, L), inputs.dtype)]


XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", )


@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
    inputs = inputs.to(dtype=loras.dtype)

    if len(loras.shape) == 4:
        loras = loras.squeeze(axis=1)

    jax_import_guard()
    kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function)

    T, _ = inputs.shape
    _, L, D = loras.shape

    # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU
    # register. This has to happen in pytorch, doing it in Jax will lead to NaNs
    L1 = L
    if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0:
        L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK

    D1 = D
    if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0:
        D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE

    T1 = T
    if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0:
        T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK

    if D1 != D or L1 != L:
        loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0))
    if D1 != D or T1 != T:
        inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T))
    if T1 != T:
        idxs = torch.nn.functional.pad(idxs, ((0, T1 - T)))

    return kernel(idxs, inputs, loras)[:T, :L]


@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
                 idxs: torch.IntTensor):
    T, _ = inputs.shape

    if len(loras.shape) == 4:
        loras = loras.squeeze(axis=1)

    _, L, _ = loras.shape

    return torch.empty((T, L), device=inputs.device)
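The padding logic in bgmv_xla rounds each of T, L and D up to the next multiple of its block size before invoking the kernel, then slices the result back to (T, L). A standalone sketch of that round-up arithmetic, using the block constants defined above:

    def round_up_to_block(x: int, block: int) -> int:
        # Mirrors the padding arithmetic in bgmv_xla: if x is not already a
        # positive multiple of `block`, round it up to the next multiple.
        if block > x or x % block != 0:
            return (x // block + 1) * block
        return x


    # With LORA_RANK_BLOCK = 128 and DIM_BLOCK_SIZE = 128:
    assert round_up_to_block(8, 128) == 128      # a rank-8 LoRA is padded to 128
    assert round_up_to_block(256, 128) == 256    # already a multiple, unchanged
    assert round_up_to_block(3072, 128) == 3072  # typical hidden size, unchanged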
@ -48,7 +48,7 @@ class PunicaWrapperABC(ABC):
        lora_a_stacked: Tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> None:
    ) -> Optional[torch.Tensor]:
        """
        Performs GEMM for multiple slices of lora_a.
        """
@ -66,7 +66,7 @@ class PunicaWrapperABC(ABC):
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
    ) -> Optional[torch.Tensor]:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.
        """
@ -80,7 +80,7 @@ class PunicaWrapperABC(ABC):
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> None:
    ) -> Optional[torch.Tensor]:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA,
        and this layer only requires the expand operation.
@ -98,7 +98,7 @@ class PunicaWrapperABC(ABC):
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applicable to linear-related lora.
        """
@ -114,7 +114,7 @@ class PunicaWrapperABC(ABC):
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.
        """
@ -207,7 +207,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
            self._long_lora_indices.zero_()
        self.indices_len[:] = indices_len

    def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:

        (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
         batch_size, max_length, token_nums,
@ -334,7 +335,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                  long_lora_context)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metada(self.token_lora_indices)
            self._update_prefill_metadata(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False
@ -342,7 +343,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
    @abstractmethod
    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs) -> None:
                   scale: float, **kwargs) -> Optional[torch.Tensor]:
        """
        Performs GEMM for multiple slices of lora_a.
@ -369,7 +370,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
                   **kwargs) -> Optional[torch.Tensor]:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.
@ -401,7 +402,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
                           **kwargs) -> Optional[torch.Tensor]:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
        and this layer only requires the expand operation.
@ -428,7 +429,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applicable to linear-related lora.
@ -463,7 +464,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.
vllm/lora/punica_wrapper/punica_tpu.py (new file, 325 lines)
@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F

from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink

from .punica_base import PunicaWrapperBase


class PunicaWrapperTPU(PunicaWrapperBase):
    """
    PunicaWrapperTPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

        # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
        # isn't supported by the TPU. So convert those tensors to int32.
        # Not all of them are used by the TPU so only convert the useful ones.
        self._token_lora_indices = self._token_lora_indices.to(
            dtype=torch.int32)
        self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
        self._sampler_indices_padded = self._sampler_indices_padded.to(
            dtype=torch.int32)

        torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
        torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
        torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)

    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        return self._embeddings_indices[:]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices.
        """
        return self._sampler_indices_padded[:]

    def shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        if self.no_lora:
            return y
        return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x),
                           scale)

    def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
               add_inputs: bool):
        return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x),
                           add_inputs)

    def expand_slice(self, y: torch.Tensor, x: torch.Tensor,
                     w_t_all: torch.Tensor, y_offset: int, y_slice_size: int,
                     y_total_size: int, add_inputs: bool) -> torch.Tensor:
        return bgmv_expand_slice(x, w_t_all, y,
                                 self._get_token_lora_indices(x), y_offset,
                                 y_slice_size, add_inputs)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs) -> Optional[torch.Tensor]:
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """

        torch.ops.xla.dynamo_set_buffer_donor_(y, True)
        x = x.view(-1, x.shape[-1])

        for slice_idx in range(len(lora_a_stacked)):
            y_s = y[slice_idx]
            lora_s = lora_a_stacked[slice_idx]
            y_s = self.shrink(y_s, x, lora_s, scale)
            y[slice_idx, :, :] = y_s  # type: ignore[index]
        return y

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> torch.Tensor:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (Tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = 0

        if lora_bias_stacked is not None:
            y = self._apply_bias(self._get_token_lora_indices(y), y,
                                 output_slices, lora_bias_stacked)
        for slice_idx in range(len(lora_b_stacked)):
            y = self.expand_slice(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_left,
                output_slices[slice_idx],
                y_total_size=sum(output_slices),
                add_inputs=add_inputs,
            )
            offset_left += output_slices[slice_idx]
        return y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> torch.Tensor:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """

        # Embedding layer only needs the expand op
        return self.expand(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> torch.Tensor:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0)+lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will not be changed in-place.
            x (torch.Tensor): Input tensor (T, E)
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """

        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(self._get_token_lora_indices(y), y,
                                 output_slices, lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            T = x.size(0)
            buffer = torch.zeros(
                (len(output_slices), T, r),
                dtype=torch.float32,
                device=x.device,
            )
        buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        return self.add_expand(y,
                               buffer,
                               lora_b_stacked,
                               None,
                               output_slices,
                               add_inputs=True,
                               **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> torch.Tensor:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor):lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]):Default to None.
        """
        if self.no_lora:
            return y

        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices,
                             scale)
        y = bgmv_expand(buffer,
                        lora_b_stacked,
                        y,
                        self.sampler_indices,
                        add_inputs=True)
        return y.view_as(y_org)

    def _apply_bias(
        self,
        indices: torch.Tensor,
        output: torch.Tensor,
        output_slices: Tuple[int, ...],
        lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
    ):
        """Applies bias to output

        Input shapes:
            lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
            indices:           (batch_size)
            output:            (batch_size, q_slice_size + 2*kv_slice_size)
            output_slices:     n-1 element tuple of (slice_size...),
                               where n is number of slices
        """
        org_output = output
        output = output.view(-1, output.shape[-1])
        indices = indices.view(-1)

        offset_left = 0
        for slice_idx, slice in enumerate(output_slices):
            bias = lora_bias_stacked[slice_idx]
            if bias is not None:
                bias = bias.view(-1, bias.shape[-1])
                bias = bias[indices]
                bias = torch.where(indices[:, None] == -1, 0, bias)

                bias = F.pad(bias, (offset_left, output.shape[1] -
                                    (offset_left + slice), 0, 0))

                output += bias
            offset_left += slice

        return output.view_as(org_output)

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        self.batch_size = 1
        self._lora_indices_per_batch[:self.batch_size].copy_(
            token_lora_tensor[:self.batch_size])
        # TODO: .item() is extremely inefficient on TPU, so find a way around it
        self.no_lora = torch.all(token_lora_tensor == -1).item()
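Putting the wrapper together: add_lora_linear is a shrink into a small rank-r float32 buffer followed by an expand back into the output slices. A dense, single-slice restatement of that composition (plain PyTorch, no XLA, all names local to this sketch):

    import torch

    T, D_in, r, D_out = 8, 16, 4, 16     # toy sizes: tokens, input dim, LoRA rank, output dim
    x = torch.randn(T, D_in)
    y = torch.zeros(T, D_out)            # stands in for the base layer's output
    lora_a = torch.randn(r, D_in)        # one adapter, "shrink" weight
    lora_b = torch.randn(D_out, r)       # one adapter, "expand" weight
    scale = 1.0

    # shrink: project the input down to the LoRA rank, scaled.
    buffer = scale * (x @ lora_a.T)      # (T, r)

    # expand: project back up and accumulate into the output.
    y = y + buffer @ lora_b.T            # (T, D_out)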
@ -125,11 +125,13 @@ def convert_mapping(
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1,
                                     embeddings_indices)
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = torch.where(sampler_indices_padded == -1,
                                         max_loras - 1, sampler_indices_padded)
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
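The convert_mapping change above swaps an in-place boolean-masked assignment for a functional torch.where, which returns a new tensor instead of mutating one; this matches the functional-update style used for the TPU/XLA path elsewhere in this PR. The two forms compute the same values, as this small check illustrates:

    import torch

    max_loras = 4
    idx = torch.tensor([0, -1, 2, -1])

    # In-place masked assignment (original form).
    a = idx.clone()
    a[a == -1] = max_loras - 1

    # Functional form used after this change: returns a new tensor.
    b = torch.where(idx == -1, max_loras - 1, idx)

    assert torch.equal(a, b)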
@ -332,6 +332,27 @@ class Platform:
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
from tpu_info import device
@ -67,6 +67,22 @@ class TpuPlatform(Platform):
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return not envs.VLLM_USE_V1

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        return 1

    @classmethod
    def inference_mode(cls):
        return torch.no_grad()
@ -39,6 +39,7 @@ from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

from .utils import sanity_check_mm_encoder_outputs
@ -90,7 +91,7 @@ MIN_NUM_SEQS = 8
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch predictions are included as subgraph inputs to facilitate
# pre-compilation.
class TPUModelRunner:
class TPUModelRunner(LoRAModelRunnerMixin):

    def __init__(
        self,
@ -568,6 +569,17 @@ class TPUModelRunner:
                                                              self.device)
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

        if self.lora_config is not None:
            # We need to respect padding when activating LoRA adapters
            padded_num_scheduled_tokens_per_req = np.copy(
                num_scheduled_tokens_per_req
            )  # Copying to avoid accidental state corruption bugs
            padded_num_scheduled_tokens_per_req[-1] += \
                padded_total_num_scheduled_tokens - total_num_scheduled_tokens

            self.set_active_loras(self.input_batch,
                                  padded_num_scheduled_tokens_per_req)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
@ -907,6 +919,11 @@ class TPUModelRunner:
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
            if self.lora_config is not None:
                model = self.load_lora_model(model, self.model_config,
                                             self.scheduler_config,
                                             self.lora_config, self.device)

        # Sync all pending XLA execution during model initialization and weight
        # loading.
        xm.mark_step()
@ -970,7 +987,10 @@ class TPUModelRunner:
            for layer_name in layer_names
        }

        with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0):
        with self.maybe_dummy_run_with_lora(
                self.lora_config,
                np.array([num_tokens], dtype=np.int32)), set_forward_context(
                    per_layer_attn_metadata, self.vllm_config, 0):
            out = self.model(input_ids=input_ids,
                             positions=position_ids,
                             inputs_embeds=inputs_embeds)
@ -15,6 +15,7 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
@ -82,6 +83,10 @@ class TPUWorker:
        if self.model_config.seed is None:
            self.model_config.seed = 0

        if vllm_config.lora_config is not None:
            raise NotImplementedError(
                "The V1 TPU backend doesn't support LoRA serving")

    def init_device(self):
        os.environ["PJRT_DEVICE"] = "TPU"
        # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
@ -211,6 +216,9 @@ class TPUWorker:
        else:
            xp.stop_trace()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def load_model(self) -> None:
        self.model_runner.load_model()
@ -54,6 +54,10 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
        if self.model_config.seed is None:
            self.model_config.seed = 0

        if vllm_config.lora_config is not None:
            raise NotImplementedError(
                "The V0 TPU backend doesn't support LoRA serving")

    def init_device(self) -> None:
        os.environ["PJRT_DEVICE"] = "TPU"
        torch.set_grad_enabled(False)