Compare commits


9 Commits

Author SHA1 Message Date
966f933ee1 [Bugfix] Fix LoRA extra vocab size (#15047)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-03-18 10:51:10 -07:00
1a504aff6c [Bugfix] Fix broken CPU quantization due to triton import (#15038)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-03-18 10:51:10 -07:00
01ca85bbd8 [MODEL] Add support for Zamba2 models (#13185)
Signed-off-by: Yury Tokpanov <yury@zyphra.com>
Signed-off-by: Quentin Anthony <qganthony@yahoo.com>
Co-authored-by: Quentin Anthony <qganthony@yahoo.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2025-03-18 10:51:10 -07:00
d82b9487ea [Bugfix] Register serializers for V0 MQ Engine (#15009)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-03-18 10:51:10 -07:00
be13281d4b [Bugfix] Loosen type check to avoid errors in V1 (#15021)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-03-18 10:51:10 -07:00
54e084f7fb [Bugfix] torchrun compatibility (#14899)
Signed-off-by: hiyouga <hiyouga@buaa.edu.cn>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
2025-03-18 10:51:10 -07:00
9e8f089d08 [Kernels] LoRA - Retire SGMV and BGMV Kernels (#14685)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2025-03-18 10:51:10 -07:00
16e9064f84 [V1] Guard Against Main Thread Usage (#14972)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-17 13:23:17 -07:00
5ac1a8e6e4 [Bugfix] Fix interface for Olmo2 on V1 (#14976)
Signed-off-by: Roger Wang <ywang@roblox.com>
2025-03-17 11:41:43 -07:00
43 changed files with 1410 additions and 2171 deletions

View File

@ -17,13 +17,8 @@ from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES from weight_shapes import WEIGHT_SHAPES
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@ -167,69 +162,25 @@ class OpType(Enum):
""" """
LoRA Ops to benchmark and its properties. LoRA Ops to benchmark and its properties.
""" """
SGMV_SHRINK = auto() LORA_SHRINK = auto()
BGMV_SHRINK = auto() LORA_EXPAND = auto()
SGMV_EXPAND = auto()
BGMV_EXPAND = auto()
BGMV_EXPAND_SLICE = auto()
V1_SHRINK = auto()
V1_EXPAND = auto()
@staticmethod @staticmethod
def from_str(s: str) -> "OpType": def from_str(s: str) -> "OpType":
if s.lower() == 'sgmv_shrink': if s.lower() == "lora_shrink":
return OpType.SGMV_SHRINK return OpType.LORA_SHRINK
if s.lower() == 'sgmv_expand': if s.lower() == "lora_expand":
return OpType.SGMV_EXPAND return OpType.LORA_EXPAND
if s.lower() == 'bgmv_shrink':
return OpType.BGMV_SHRINK
if s.lower() == 'bgmv_expand':
return OpType.BGMV_EXPAND
if s.lower() == "bgmv_expand_slice":
return OpType.BGMV_EXPAND_SLICE
if s.lower() == "v1_shrink":
return OpType.V1_SHRINK
if s.lower() == "v1_expand":
return OpType.V1_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType") raise ValueError(f"Unrecognized str {s} to convert to OpType")
def is_shrink_fn(self) -> bool: def is_shrink_fn(self) -> bool:
return self in [ return self in [OpType.LORA_SHRINK]
OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
]
def is_expand_fn(self) -> bool: def is_expand_fn(self) -> bool:
return self in [ return self in [OpType.LORA_EXPAND]
OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
]
def is_prefill_op(self) -> bool:
return self in [
OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
OpType.V1_EXPAND
]
def is_decode_op(self) -> bool:
return self in [
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
OpType.V1_SHRINK, OpType.V1_EXPAND
]
def is_expand_slice_fn(self) -> bool:
return self in [OpType.BGMV_EXPAND_SLICE]
def num_slices(self) -> list[int]: def num_slices(self) -> list[int]:
if self in [ return [1, 2, 3]
OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
OpType.V1_EXPAND
]:
# SGMV kernels and v1 kernels supports slices
return [1, 2, 3]
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
return [1]
if self in [OpType.BGMV_EXPAND_SLICE]:
return [2, 3]
raise ValueError(f"Unrecognized OpType {self}")
def mkn(self, batch_size: int, seq_length: int, hidden_size: int, def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int) -> tuple[int, int, int]: lora_rank: int) -> tuple[int, int, int]:
@ -239,7 +190,7 @@ class OpType(Enum):
k = hidden_size k = hidden_size
n = lora_rank n = lora_rank
else: else:
assert self.is_expand_fn() or self.is_expand_slice_fn() assert self.is_expand_fn()
m = num_tokens m = num_tokens
k = lora_rank k = lora_rank
n = hidden_size n = hidden_size
@ -254,7 +205,7 @@ class OpType(Enum):
if self.is_shrink_fn(): if self.is_shrink_fn():
return op_dtype, op_dtype, torch.float32 return op_dtype, op_dtype, torch.float32
else: else:
assert self.is_expand_fn() or self.is_expand_slice_fn() assert self.is_expand_fn()
return torch.float32, op_dtype, op_dtype return torch.float32, op_dtype, op_dtype
def matmul_shapes( def matmul_shapes(
@ -268,43 +219,19 @@ class OpType(Enum):
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank) m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
b_shape = (num_loras, n, k) # col-major b_shape = (num_loras, n, k) # col-major
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]: if self in [OpType.LORA_SHRINK]:
# SGMV shrink and V1 shrink kernels support num_slices inherently # LoRA shrink kernels support num_slices inherently in the kernel.
# in the kernel.
return ((m, k), b_shape, (num_slices, m, n)) return ((m, k), b_shape, (num_slices, m, n))
if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]: if self in [OpType.LORA_EXPAND]:
# SGMV expand and V1 expand kernels support num_slices inherently # LoRA expand kernels support num_slices inherently in the kernel
# in the kernel
return ((num_slices, m, k), b_shape, (m, n * num_slices)) return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self == OpType.BGMV_SHRINK:
return ((m, k), b_shape, (m, n))
if self == OpType.BGMV_EXPAND:
return ((m, k), b_shape, (m, n))
if self == OpType.BGMV_EXPAND_SLICE:
return ((num_slices, m, k), b_shape, (m, n * num_slices))
raise ValueError(f"Unrecognized op_type {self}") raise ValueError(f"Unrecognized op_type {self}")
def bench_fn(self) -> Callable: def bench_fn(self) -> Callable:
if self == OpType.LORA_SHRINK:
def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]): return lora_shrink
for x in kwargs_list: if self == OpType.LORA_EXPAND:
bgmv_expand_slice(**x) return lora_expand
if self == OpType.SGMV_SHRINK:
return sgmv_shrink
if self == OpType.SGMV_EXPAND:
return sgmv_expand
if self == OpType.BGMV_SHRINK:
return bgmv_shrink
if self == OpType.BGMV_EXPAND:
return bgmv_expand
if self == OpType.BGMV_EXPAND_SLICE:
return emulate_bgmv_expand_slice
if self == OpType.V1_SHRINK:
return v1_shrink
if self == OpType.V1_EXPAND:
return v1_expand
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
@ -318,34 +245,13 @@ class OpType(Enum):
""" """
w_dtype = lora_weights[0].dtype w_dtype = lora_weights[0].dtype
num_slices = len(lora_weights) num_slices = len(lora_weights)
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]: if self in [OpType.LORA_SHRINK]:
for slice_idx in range(num_slices): for slice_idx in range(num_slices):
ref_group_gemm(ref_out=output[slice_idx, :], ref_group_gemm(ref_out=output[slice_idx, :],
input=input, input=input,
lora_weights=lora_weights[slice_idx], lora_weights=lora_weights[slice_idx],
**kwargs) **kwargs)
elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]: elif self in [OpType.LORA_EXPAND]:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
ref_group_gemm(
ref_out=output[:, slice_offset:slice_offset + hidden_size],
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
elif self == OpType.BGMV_SHRINK:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input,
lora_weights=lora_weights[0],
**kwargs)
elif self == OpType.BGMV_EXPAND:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input.clone().to(dtype=w_dtype),
lora_weights=lora_weights[0],
**kwargs)
elif self == OpType.BGMV_EXPAND_SLICE:
hidden_size = lora_weights[0].shape[1] hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices): for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size slice_offset = slice_idx * hidden_size
@ -411,13 +317,11 @@ class BenchmarkTensors:
input: torch.Tensor input: torch.Tensor
lora_weights_lst: list[torch.Tensor] lora_weights_lst: list[torch.Tensor]
output: torch.Tensor output: torch.Tensor
# metadata tensors # LoRA kernel metadata
lora_kernel_meta: LoRAKernelMeta
# Metadata tensors used in testing correctness
seq_lens: torch.Tensor seq_lens: torch.Tensor
seq_start_loc: torch.Tensor
prompt_lora_mapping: torch.Tensor prompt_lora_mapping: torch.Tensor
token_lora_mapping: torch.Tensor
# v1 kernel metadata
v1_kernel_meta: Optional[V1KernelMeta] = None
def io_types(self) -> str: def io_types(self) -> str:
return (f"{dtype_to_str(self.input.dtype)}x" return (f"{dtype_to_str(self.input.dtype)}x"
@ -444,35 +348,29 @@ class BenchmarkTensors:
assert ctx.num_active_loras <= ctx.num_loras assert ctx.num_active_loras <= ctx.num_loras
total_tokens = ctx.batch_size * ctx.seq_length total_tokens = ctx.batch_size * ctx.seq_length
# Make metadata tensors involved in correctness testing.
# Prepare seq lens tensor # Prepare seq lens tensor
seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1, seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
(ctx.batch_size, )) (ctx.batch_size, ))
# Prepare seq_start_loc tensor
seq_start_loc_tensor = torch.cumsum(torch.tensor(
[0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
dim=0)
assert total_tokens == seq_len_tensor.sum() assert total_tokens == seq_len_tensor.sum()
# Prepare prompt lora indices tensor # Prepare prompt lora indices tensor
prompt_lora_indices_tensor = make_prompt_lora_mapping( prompt_lora_indices_tensor = make_prompt_lora_mapping(
ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu") ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
# Prepare token lora indices tensor
# Make LoRAKernelMeta
token_lora_indices_tensor = make_token_lora_mapping( token_lora_indices_tensor = make_token_lora_mapping(
total_tokens, ctx.batch_size, prompt_lora_indices_tensor, total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
seq_len_tensor, "cpu") seq_len_tensor, "cpu")
lora_kernel_meta = LoRAKernelMeta.make(
v1_kernel_meta = None max_loras=ctx.num_loras,
if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]: max_num_tokens=token_lora_indices_tensor.size(0),
v1_kernel_meta = V1KernelMeta.make( device="cpu")
max_loras=ctx.num_loras, lora_kernel_meta.prepare_tensors(
max_num_tokens=token_lora_indices_tensor.size(0), token_lora_mapping=token_lora_indices_tensor)
device="cpu")
v1_kernel_meta.prepare_tensors(
token_lora_mapping=token_lora_indices_tensor)
return BenchmarkTensors(input_tensor, lora_weights, output_tensor, return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
seq_len_tensor, seq_start_loc_tensor, lora_kernel_meta, seq_len_tensor,
prompt_lora_indices_tensor, prompt_lora_indices_tensor)
token_lora_indices_tensor, v1_kernel_meta)
def sanity_check(self) -> None: def sanity_check(self) -> None:
""" """
@ -482,9 +380,9 @@ class BenchmarkTensors:
# check metadata tensors # check metadata tensors
assert torch.sum(self.seq_lens) == num_tokens assert torch.sum(self.seq_lens) == num_tokens
num_seqs = self.seq_lens.shape[0] num_seqs = self.seq_lens.shape[0]
assert self.seq_start_loc.shape[0] == num_seqs #assert self.seq_start_loc.shape[0] == num_seqs
assert self.prompt_lora_mapping.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs
assert self.token_lora_mapping.shape[0] == num_tokens assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
def to_device(self, device: str): def to_device(self, device: str):
""" """
@ -499,220 +397,27 @@ class BenchmarkTensors:
self.input = to_device(self.input) self.input = to_device(self.input)
self.output = to_device(self.output) self.output = to_device(self.output)
self.seq_lens = to_device(self.seq_lens) self.seq_lens = to_device(self.seq_lens)
self.seq_start_loc = to_device(self.seq_start_loc)
self.prompt_lora_mapping = to_device(self.prompt_lora_mapping) self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
self.token_lora_mapping = to_device(self.token_lora_mapping)
for i in range(len(self.lora_weights_lst)): for i in range(len(self.lora_weights_lst)):
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i]) self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
# v1 meta # LoRA meta
if self.v1_kernel_meta: for field_name in LoRAKernelMeta.__dataclass_fields__:
for field_name in V1KernelMeta.__dataclass_fields__: field = getattr(self.lora_kernel_meta, field_name)
field = getattr(self.v1_kernel_meta, field_name) assert isinstance(field, torch.Tensor)
assert isinstance(field, torch.Tensor) setattr(self.lora_kernel_meta, field_name, to_device(field))
setattr(self.v1_kernel_meta, field_name, to_device(field))
def metadata(self) -> tuple[int, int, int]: def metadata(self) -> tuple[int, int, int]:
""" """
Return num_seqs, num_tokens and max_seq_len Return num_seqs, num_tokens and max_seq_len
""" """
num_seqs = self.seq_lens.shape[0] num_seqs = self.seq_lens.shape[0]
num_tokens = self.token_lora_mapping.shape[0] num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
max_seq_len = torch.max(self.seq_lens).item() max_seq_len = torch.max(self.seq_lens).item()
num_slices = len(self.lora_weights_lst) num_slices = len(self.lora_weights_lst)
return num_seqs, num_tokens, max_seq_len, num_slices return num_seqs, num_tokens, max_seq_len, num_slices
def convert_to_sgmv_benchmark_tensors(self): def as_lora_shrink_kwargs(self) -> dict[str, Any]:
"""
For sgmv punica kernels, when consecutive sequences have the
same LoRA ID, we just merge them together.
This happens in punica.py::compute_metadata
"""
# Collapse seq_lens and seq_start_loc
_, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
return_counts=True)
cum_result = torch.cumsum(seq_lens, dim=0)
seq_start_loc = torch.zeros_like(seq_lens)
seq_start_loc[1:].copy_(cum_result[:-1])
# Collapse prompt mapping
prompt_lora_mapping = torch.unique_consecutive(
self.prompt_lora_mapping)
assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"
self.prompt_lora_mapping = prompt_lora_mapping.to(
dtype=self.prompt_lora_mapping.dtype)
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors()
self.sanity_check()
self.to_device(self.input.device)
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert len(lw_shape) == 3
assert lw_shape[2] == hidden_size
lora_rank = lw_shape[1]
# Expected output shape [num_slices, num_tokens, lora_rank]
assert len(o_shape) == 3
assert o_shape == (num_slices, num_tokens, lora_rank)
return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output,
'b_seq_start_loc': self.seq_start_loc,
'seq_len_tensor': self.seq_lens,
'lora_indices_tensor': self.prompt_lora_mapping,
'batches': num_seqs,
'max_seq_length': max_seq_len,
'token_nums': num_tokens,
'scaling': 1.0,
}
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors()
self.sanity_check()
self.to_device(self.input.device)
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[2]
# Expected lora weight shape : [num_lora, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape : [num_tokens, hidden_size * num_slices]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size * num_slices)
return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output,
'b_seq_start_loc': self.seq_start_loc,
'seq_len_tensor': self.seq_lens,
'lora_indices_tensor': self.prompt_lora_mapping,
'batches': num_seqs,
'max_seq_length': max_seq_len,
'token_nums': num_tokens,
'offset_start': 0,
'add_inputs': add_inputs,
}
def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
assert len(self.lora_weights_lst) == 1
self.to_device(self.input.device)
_, num_tokens, _, _ = self.metadata()
# Sanity check shapes
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert len(lw_shape) == 3
assert lw_shape[2] == hidden_size
lora_rank = lw_shape[1]
# Expected output shape [num_tokens, lora_rank]
assert len(o_shape) == 2
assert o_shape == (num_tokens, lora_rank)
return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst[0],
'output_tensor': self.output,
'lora_indices_tensor': self.token_lora_mapping,
'scaling': 1.0
}
def as_bgmv_expand_kwargs(self, add_inputs: bool):
assert len(self.lora_weights_lst) == 1
self.to_device(self.input.device)
_, num_tokens, _, _ = self.metadata()
# Sanity check shapes
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, lora_rank]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
lora_rank = i_shape[1]
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape [num_tokens, hidden_size]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size)
return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst[0],
'output_tensor': self.output,
'lora_indices_tensor': self.token_lora_mapping,
'add_inputs': add_inputs
}
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
_, num_tokens, _, num_slices = self.metadata()
# Sanity check shapes
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[2]
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape [num_tokens, hidden_size * num_slices]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size * num_slices)
self.to_device(self.input.device)
kwargs_list = []
for i in range(num_slices):
kwargs_list.append({
'inputs': self.input[i],
'lora_b_weights': self.lora_weights_lst[i],
'output_tensor': self.output,
'lora_indices_tensor': self.token_lora_mapping,
'slice_offset': i * hidden_size,
'slice_size': hidden_size,
'add_inputs': add_inputs,
})
return {'kwargs_list': kwargs_list}
def as_v1_shrink_kwargs(self) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check() self.sanity_check()
self.to_device(self.input.device) self.to_device(self.input.device)
@ -737,17 +442,16 @@ class BenchmarkTensors:
'inputs': self.input, 'inputs': self.input,
'lora_a_weights': self.lora_weights_lst, 'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output, 'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping, 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids': 'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids, self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora, 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc, 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids, 'lora_ids': self.lora_kernel_meta.active_lora_ids,
'scaling': 1.0, 'scaling': 1.0,
} }
def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check() self.sanity_check()
self.to_device(self.input.device) self.to_device(self.input.device)
@ -773,12 +477,12 @@ class BenchmarkTensors:
'inputs': self.input, 'inputs': self.input,
'lora_b_weights': self.lora_weights_lst, 'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output, 'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping, 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids': 'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids, self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora, 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc, 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids, 'lora_ids': self.lora_kernel_meta.active_lora_ids,
'offset_start': 0, 'offset_start': 0,
'add_inputs': add_inputs, 'add_inputs': add_inputs,
} }
@ -791,20 +495,10 @@ class BenchmarkTensors:
else: else:
assert add_inputs is not None assert add_inputs is not None
if op_type == OpType.SGMV_SHRINK: if op_type == OpType.LORA_SHRINK:
return self.as_sgmv_shrink_kwargs() return self.as_lora_shrink_kwargs()
if op_type == OpType.SGMV_EXPAND: if op_type == OpType.LORA_EXPAND:
return self.as_sgmv_expand_kwargs(add_inputs) return self.as_lora_expand_kwargs(add_inputs)
if op_type == OpType.BGMV_SHRINK:
return self.as_bgmv_shrink_kwargs()
if op_type == OpType.BGMV_EXPAND:
return self.as_bgmv_expand_kwargs(add_inputs)
if op_type == OpType.BGMV_EXPAND_SLICE:
return self.as_bgmv_expand_slice_kwargs(add_inputs)
if op_type == OpType.V1_SHRINK:
return self.as_v1_shrink_kwargs()
if op_type == OpType.V1_EXPAND:
return self.as_v1_expand_kwargs(add_inputs)
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
def test_correctness(self, op_type: OpType, def test_correctness(self, op_type: OpType,
@ -993,10 +687,6 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
for bench_ctx in bench_ctxs: for bench_ctx in bench_ctxs:
for seq_len in args.seq_lengths: for seq_len in args.seq_lengths:
bench_ops: list[OpType] = args.op_types bench_ops: list[OpType] = args.op_types
if seq_len > 1:
# bench only prefill ops
bench_ops = [op for op in args.op_types if op.is_prefill_op()]
seq_len_timers = [] seq_len_timers = []
for bench_op in bench_ops: for bench_op in bench_ops:
for num_slices in bench_op.num_slices(): for num_slices in bench_op.num_slices():
@ -1206,13 +896,13 @@ Benchmark LoRA kernels:
{use_cuda_graph_recommendation()} {use_cuda_graph_recommendation()}
list_bench example: list_bench example:
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
model_bench example: model_bench example:
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
range_bench example: range_bench example:
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
""", # noqa: E501 """, # noqa: E501
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter)
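
With the SGMV and BGMV entry points retired, this benchmark drives everything through the unified lora_shrink/lora_expand kernels and LoRAKernelMeta. A minimal sketch of calling the new shrink kernel directly, following the call sites shown in this file (the sizes and values are made up for illustration; a CUDA device and the new vllm.lora.ops.triton_ops API above are assumed):

import torch

from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_shrink

num_tokens, hidden_size, lora_rank = 32, 2048, 16
num_loras, num_slices = 4, 2

# Shrink input is [num_tokens, hidden_size]; each slice of lora_a weights is
# [num_loras, lora_rank, hidden_size]; the output is float32 and shaped
# [num_slices, num_tokens, lora_rank], matching matmul_shapes() above.
inputs = torch.rand((num_tokens, hidden_size),
                    dtype=torch.float16, device="cuda")
lora_a_weights = [
    torch.rand((num_loras, lora_rank, hidden_size),
               dtype=torch.float16, device="cuda") for _ in range(num_slices)
]
output = torch.zeros((num_slices, num_tokens, lora_rank),
                     dtype=torch.float32, device="cuda")

# Every token is mapped to one of the active LoRAs.
token_lora_mapping = torch.randint(0, num_loras, (num_tokens, ), device="cuda")

meta = LoRAKernelMeta.make(max_loras=num_loras,
                           max_num_tokens=num_tokens,
                           device="cuda")
meta.prepare_tensors(token_lora_mapping=token_lora_mapping)

# meta_args() unpacks into the token_lora_mapping, token_indices_sorted_by_lora_ids,
# num_tokens_per_lora, lora_token_start_loc and lora_ids arguments used above.
lora_shrink(inputs, lora_a_weights, output,
            *meta.meta_args(token_nums=num_tokens), 1.0)

lora_expand is driven the same way, with lora_b_weights plus the extra offset_start and add_inputs arguments, and with the float32/op-dtype roles of input and output swapped, as in as_lora_expand_kwargs() above.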

View File

@ -477,6 +477,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
*
*
::: :::
:::{note} :::{note}

View File

@ -93,7 +93,6 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
max_num_seqs=2, max_num_seqs=2,
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
lora_extra_vocab_size=0,
limit_mm_per_prompt={"audio": audio_count}, limit_mm_per_prompt={"audio": audio_count},
) )

View File

@ -682,7 +682,6 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2, max_num_seqs=2,
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
lora_extra_vocab_size=0,
) )
return ModelRequestData( return ModelRequestData(

View File

@ -342,7 +342,6 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
lora_extra_vocab_size=0,
) )
placeholders = "".join(f"<|image_{i}|>" placeholders = "".join(f"<|image_{i}|>"

View File

@ -4,18 +4,13 @@ from threading import Lock
import pytest import pytest
import torch import torch
import vllm.lora.ops.triton_ops # noqa: F401 import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops.v1 # noqa: F401 import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, from vllm.lora.ops.triton_ops import LoRAKernelMeta
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .utils import (PunicaTensors, assert_close, generate_data, from .utils import PunicaTensors, assert_close, generate_data_for_nslices
generate_data_for_expand_nslices,
generate_data_for_nslices)
# Utility shrink and expand operations used as reference implementations. # Utility shrink and expand operations used as reference implementations.
@ -26,10 +21,10 @@ def sgmv_shrink_for_nslices(
prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int, prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
num_tokens: int, scaling: float): num_tokens: int, scaling: float):
""" """
Wrapper around sgmv_shrink that handles any nslices. Wrapper around torch_ops.sgmv_shrink that handles any nslices.
""" """
for index in range(nslices): for index in range(nslices):
sgmv_shrink( torch_ops.sgmv_shrink(
inputs_tensor, inputs_tensor,
lora_weights_lst[index], lora_weights_lst[index],
out_tensor[index], out_tensor[index],
@ -53,11 +48,11 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
max_seq_length: int, num_tokens: int, max_seq_length: int, num_tokens: int,
add_inputs: bool) -> None: add_inputs: bool) -> None:
""" """
Wrapper around sgmv_expand that handles any nslices. Wrapper around torch_ops.sgmv_expand that handles any nslices.
""" """
if nslices == 1: if nslices == 1:
# Verify the torch's sgmv_expand op # Verify the torch's sgmv_expand op
sgmv_expand( torch_ops.sgmv_expand(
inputs_tensor[0], inputs_tensor[0],
lora_weights_lst[0], lora_weights_lst[0],
out_tensor, out_tensor,
@ -73,7 +68,7 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
slice_offset = 0 slice_offset = 0
for index in range(nslices): for index in range(nslices):
lora_weights = lora_weights_lst[index] lora_weights = lora_weights_lst[index]
sgmv_expand_slice( torch_ops.sgmv_expand_slice(
inputs_tensor[index], inputs_tensor[index],
lora_weights, lora_weights,
out_tensor, out_tensor,
@ -93,12 +88,13 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
_dict_lock = Lock() _dict_lock = Lock()
def check_shrink_kernels(batches: int, num_loras: int, rank: int, def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype, hidden_size: int, nslices: int,
device: str, seq_length: int, scaling: float): dtype: torch.dtype, device: str, seq_length: int,
scaling: float):
""" """
Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernel against a Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
reference implementation. kernels.
""" """
data: PunicaTensors = generate_data_for_nslices( data: PunicaTensors = generate_data_for_nslices(
batches, batches,
@ -118,35 +114,24 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int,
data.prompt_lora_mapping, batches, max_seq_length, data.prompt_lora_mapping, batches, max_seq_length,
token_nums) token_nums)
# Setup metadata information for the V1 kernel. # Setup metadata information for the LoRA kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras, lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums, max_num_tokens=token_nums,
device='cuda') device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping) lora_meta.prepare_tensors(data.token_lora_mapping)
ref_out_tensor = data.ref_out_tensor ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor out_tensor = data.our_out_tensor.clone()
v1_out_tensor = data.our_out_tensor.clone()
# Preventing cache error pointer. # Preventing cache error pointer.
with _dict_lock: with _dict_lock:
# SGMV shrink kernel # lora_shrink kernel
_LORA_A_PTR_DICT.clear() _LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink( triton_ops.lora_shrink(
data.inputs_tensor, data.inputs_tensor,
data.lora_weights, data.lora_weights,
sgmv_out_tensor, out_tensor,
*sgmv_meta_args, *lora_meta.meta_args(token_nums=token_nums),
scaling,
)
# V1 shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.v1_shrink(
data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
scaling, scaling,
) )
@ -160,16 +145,16 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int,
scaling, scaling,
) )
assert_close(sgmv_out_tensor, ref_out_tensor) assert_close(out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
def check_expand_kernels(batches: int, num_loras: int, rank: int, def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype, hidden_size: int, nslices: int,
device: str, seq_length: int, add_inputs: bool): dtype: torch.dtype, device: str, seq_length: int,
add_inputs: bool):
""" """
Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
reference implementation. kernels.
""" """
data: PunicaTensors = generate_data_for_nslices( data: PunicaTensors = generate_data_for_nslices(
batches, batches,
@ -190,37 +175,25 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int,
data.prompt_lora_mapping, batches, max_seq_length, data.prompt_lora_mapping, batches, max_seq_length,
token_nums) token_nums)
# Setup metadata information for the V1 kernel. # Setup metadata information for the LoRA kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras, lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums, max_num_tokens=token_nums,
device='cuda') device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping) lora_meta.prepare_tensors(data.token_lora_mapping)
# Setup output tensors # Setup output tensors
ref_out_tensor = data.ref_out_tensor ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor out_tensor = data.our_out_tensor.clone()
v1_out_tensor = data.our_out_tensor.clone()
with _dict_lock: with _dict_lock:
# SGMV expand kernel # lora_expand kernel
_LORA_B_PTR_DICT.clear() _LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand( triton_ops.lora_expand(data.inputs_tensor,
data.inputs_tensor, data.lora_weights,
data.lora_weights, out_tensor,
sgmv_out_tensor, *lora_meta.meta_args(token_nums=token_nums),
*sgmv_meta_args, offset_start=0,
offset_start=0, add_inputs=add_inputs)
add_inputs=add_inputs,
)
# V1 expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.v1_expand(data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
offset_start=0,
add_inputs=add_inputs)
# Reference # Reference
sgmv_expand_for_nslices(nslices, sgmv_expand_for_nslices(nslices,
@ -231,124 +204,7 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int,
*sgmv_meta_args, *sgmv_meta_args,
add_inputs=add_inputs) add_inputs=add_inputs)
assert_close(sgmv_out_tensor, ref_out_tensor) assert_close(out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
scaling: float):
"""
Compare vllm.bgmv_shrink against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"shrink",
device,
)
torch.ops.vllm.bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
scaling,
)
bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
scaling,
)
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
add_inputs: bool):
"""
Compare vllm.bgmv_expand against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"expand",
device,
)
torch.ops.vllm.bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, add_inputs: bool):
"""
Compare vllm.bgmv_expand_slice against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
slice_offset = 0
for index in range(nslices):
torch.ops.vllm.bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.our_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.ref_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
slice_offset += hidden_size
assert_close(data.our_out_tensor, data.ref_out_tensor)
# Tests # Tests
@ -490,31 +346,31 @@ def test_kernels(
op_type: str, op_type: str,
): ):
""" """
Tests SGMV and V1 kernels. Tests LoRA kernels.
""" """
torch.set_default_device(device) torch.set_default_device(device)
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
if op_type == "shrink": if op_type == "shrink":
check_shrink_kernels(batches=batches, check_lora_shrink_kernel(batches=batches,
num_loras=num_loras, num_loras=num_loras,
rank=rank, rank=rank,
hidden_size=hidden_size, hidden_size=hidden_size,
nslices=nslices, nslices=nslices,
dtype=dtype, dtype=dtype,
device=device, device=device,
seq_length=128, seq_length=128,
scaling=0.5) scaling=0.5)
else: else:
check_expand_kernels(batches=batches, check_lora_expand_kernel(batches=batches,
num_loras=num_loras, num_loras=num_loras,
rank=rank, rank=rank,
hidden_size=hidden_size, hidden_size=hidden_size,
nslices=nslices, nslices=nslices,
dtype=dtype, dtype=dtype,
device=device, device=device,
seq_length=128, seq_length=128,
add_inputs=True) add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches']) @pytest.mark.parametrize("batches", hs_test_params['batches'])
@ -538,159 +394,28 @@ def test_kernels_hidden_size(
op_type: str, op_type: str,
): ):
""" """
Tests SGMV and V1 kernels. Tests SGMV and LoRA kernels.
""" """
torch.set_default_device(device) torch.set_default_device(device)
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
if op_type == "shrink": if op_type == "shrink":
check_shrink_kernels(batches=batches, check_lora_shrink_kernel(batches=batches,
num_loras=num_loras, num_loras=num_loras,
rank=rank, rank=rank,
hidden_size=hidden_size, hidden_size=hidden_size,
nslices=nslices, nslices=nslices,
dtype=dtype, dtype=dtype,
device=device, device=device,
seq_length=128, seq_length=128,
scaling=0.5) scaling=0.5)
else: else:
check_expand_kernels(batches=batches, check_lora_expand_kernel(batches=batches,
num_loras=num_loras, num_loras=num_loras,
rank=rank, rank=rank,
hidden_size=hidden_size, hidden_size=hidden_size,
nslices=nslices, nslices=nslices,
dtype=dtype, dtype=dtype,
device=device, device=device,
seq_length=128, seq_length=128,
add_inputs=True) add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv_hidden_size(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str,
seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int,
rank: int, hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str, seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)

View File

@ -9,7 +9,7 @@ from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal from ...utils import check_outputs_equal
# This test is for the hybrid models # This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev"] MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
# Bamba at Fp32 is too big for the CI (L4 GPU). # Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
@ -27,17 +27,19 @@ def test_models(
) -> None: ) -> None:
# numeric error produces different generation # numeric error produces different generation
if 'Bamba' in model: if "Bamba" in model:
example_prompts.pop(3) example_prompts.pop(3)
with hf_runner( model_kwargs = {
model, "use_mamba_kernels": False, # mamba kernels are not installed so HF
dtype=dtype, # don't use them
model_kwargs={ }
"use_mamba_kernels": if "Zamba2" in model:
False, # mamba kernels are not installed so HF # Zamba2 HF implementation automatically checks if mamba kernels are
# don't use them # installed
}) as hf_model: model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype) as vllm_model:
@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
model: str, dtype: str, model: str, dtype: str,
max_tokens: int) -> None: max_tokens: int) -> None:
# numeric error during prefill chucking produces different generation # numeric error during prefill chunking produces different generation
# compared to w/o prefill chunking for those examples, removed them for now # compared to w/o prefill chunking for those examples, removed them for now
if 'Jamba' in model: if "Jamba" in model:
example_prompts.pop(7) example_prompts.pop(7)
example_prompts.pop(2) example_prompts.pop(2)
example_prompts.pop(1) example_prompts.pop(1)
elif 'Bamba' in model: elif "Bamba" in model:
example_prompts.pop(6) example_prompts.pop(6)
example_prompts.pop(3) example_prompts.pop(3)
example_prompts.pop(2) example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba dtype = "half" # use a different dtype for Bamba
elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"
with hf_runner( model_kwargs = {
model, "use_mamba_kernels": False, # mamba kernels are not installed so HF
dtype=dtype, # don't use them
model_kwargs={ }
"use_mamba_kernels": if "Zamba2" in model:
False, # mamba kernels are not installed so HF # Zamba2 HF implementation automatically checks if mamba kernels are
# don't use them # installed
}) as hf_model: model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, with vllm_runner(model,

View File

@ -100,7 +100,6 @@ def run_test(
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
lora_extra_vocab_size=0,
gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI
enforce_eager=True, enforce_eager=True,
) as vllm_model: ) as vllm_model:

View File

@ -195,6 +195,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False, is_available_online=False,
trust_remote_code=True), trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
min_transformers_version="4.49"),
# [Encoder-decoder] # [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"), "BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),

View File

@ -821,6 +821,11 @@ class ModelConfig:
if qk_rope_head_dim and qk_nope_head_dim: if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim return qk_rope_head_dim + qk_nope_head_dim
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
return self.hf_text_config.attention_head_dim
if self.is_attention_free: if self.is_attention_free:
return 0 return 0
@ -904,7 +909,9 @@ class ModelConfig:
else: else:
total_num_hidden_layers = getattr(self.hf_text_config, total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0) "num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size # the layout order is: DP x PP x TP
pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
) % parallel_config.pipeline_parallel_size
pp_size = parallel_config.pipeline_parallel_size pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return start, end return start, end
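
The corrected expression recovers the pipeline rank under the DP x PP x TP layout (TP varies fastest, then PP, then DP); dividing by the TP size alone over-counts once data parallelism is involved. A small worked example with hypothetical sizes, not from the patch:

# DP=2, PP=2, TP=2 -> world_size=8
DP, PP, TP = 2, 2, 2

for rank in range(DP * PP * TP):
    tp_rank = rank % TP
    pp_rank = (rank // TP) % PP      # corrected: modulo the PP size
    old_pp_rank = rank // TP         # previous formula
    dp_rank = rank // (TP * PP)
    print(f"rank={rank} dp={dp_rank} pp={pp_rank} tp={tp_rank} old_pp={old_pp_rank}")

# For ranks 4..7 (the second DP replica) the old formula yields pipeline ranks
# 2 and 3, out of range for pipeline_parallel_size=2; the new one wraps them
# back to 0 and 1.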
@ -942,6 +949,15 @@ class ModelConfig:
"cannot determine the num of " "cannot determine the num of "
f"{block_type.value} layers") f"{block_type.value} layers")
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value return sum(t == block_type.value
for t in layers_block_type_value[start:end]) for t in layers_block_type_value[start:end])
@ -2308,7 +2324,7 @@ class LoRAConfig:
# Setting the maximum rank to 512 should be able to satisfy the vast # Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications. # majority of applications.
possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
possible_lora_extra_vocab_size = (0, 256, 512) possible_lora_extra_vocab_size = (256, 512)
if self.max_lora_rank not in possible_max_ranks: if self.max_lora_rank not in possible_max_ranks:
raise ValueError( raise ValueError(
f"max_lora_rank ({self.max_lora_rank}) must be one of " f"max_lora_rank ({self.max_lora_rank}) must be one of "

View File

@ -897,10 +897,23 @@ def initialize_model_parallel(
get_world_group().device_group) get_world_group().device_group)
data_parallel_size = 1 data_parallel_size = 1
has_external_dp = False
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
config = get_current_vllm_config() config = get_current_vllm_config()
if config is not None: if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size if config.parallel_config.world_size != world_size:
# detect external data parallelism.
# dp in vllm means all dp instances need to run together.
# if the world size does not match, it means this dp is external,
# and the dp instances can run independently, e.g. in rlhf workflow
# from https://github.com/volcengine/verl .
# in that case, we treat the rest dimensions as if they are
# data parallel, and create a dummy dp group that is not used.
data_parallel_size = world_size // (pipeline_model_parallel_size *
tensor_model_parallel_size)
has_external_dp = True
else:
data_parallel_size = config.parallel_config.data_parallel_size
# the layout order is: DP x PP x TP # the layout order is: DP x PP x TP
# to get group_ranks for each dimension, transpose that dimension to the # to get group_ranks for each dimension, transpose that dimension to the
@ -940,6 +953,12 @@ def initialize_model_parallel(
2).reshape(-1, 2).reshape(-1,
data_parallel_size).unbind(0) data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if has_external_dp:
# create a dummy dp group that is not used actually,
# since this dp is external.
# a dummy dp group means every rank is a group itself.
# this way, no communication is needed, no memory is wasted.
group_ranks = [[x] for x in range(world_size)]
_DP = init_model_parallel_group(group_ranks, _DP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
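
The external-DP branch above is what makes the torchrun compatibility change in this comparison work: when the launcher supplies more ranks than tensor_parallel_size * pipeline_parallel_size, the leftover factor is treated as externally managed data parallelism. A sketch of the resulting dummy DP grouping, with hypothetical sizes not taken from the patch:

world_size = 8
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 1
configured_world_size = (tensor_model_parallel_size *
                         pipeline_model_parallel_size)          # 2

if configured_world_size != world_size:
    # The leftover dimension is treated as external data parallelism.
    data_parallel_size = world_size // configured_world_size    # 4
    # Dummy DP groups: every rank is its own one-element group, so no DP
    # collective is ever issued and no communicator memory is allocated.
    group_ranks = [[x] for x in range(world_size)]

print(data_parallel_size)   # 4
print(group_ranks)          # [[0], [1], [2], [3], [4], [5], [6], [7]]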

View File

@ -3,6 +3,7 @@
import argparse import argparse
import dataclasses import dataclasses
import json import json
import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast, get_args) Tuple, Type, Union, cast, get_args)
@ -1576,6 +1577,11 @@ class EngineArgs:
############################################################# #############################################################
# Experimental Features - allow users to opt in. # Experimental Features - allow users to opt in.
# Signal Handlers requires running in main thread.
if (threading.current_thread() != threading.main_thread()
and _warn_or_fallback("Engine in background thread")):
return False
# LoRA is supported on V1, but off by default for now. # LoRA is supported on V1, but off by default for now.
if self.enable_lora and _warn_or_fallback("LORA"): if self.enable_lora and _warn_or_fallback("LORA"):
return False return False
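
The new guard reflects a CPython restriction rather than anything vLLM-specific: signal handlers can only be installed from the main thread of the main interpreter, so an engine constructed in a background thread must fall back. A small illustration of the underlying restriction (not part of the patch):

import signal
import threading

def install_handler():
    try:
        signal.signal(signal.SIGTERM, lambda signum, frame: None)
        print("handler installed")
    except ValueError as e:
        # Raised when not in the main thread of the main interpreter.
        print("cannot install handler:", e)

install_handler()                              # works in the main thread
t = threading.Thread(target=install_handler)   # raises ValueError inside the thread
t.start()
t.join()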

View File

@ -29,6 +29,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
# yapf: enable # yapf: enable
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.worker.model_runner_base import InputProcessingError from vllm.worker.model_runner_base import InputProcessingError
@ -42,12 +44,12 @@ class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`. """A multiprocessing wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to enable use This class is used to wrap the :class:`LLMEngine` class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc. receive new requests and stream outputs incrementally via ipc.
The :class:`LLMEngine` generate or encode process is kicked off when a new The :class:`LLMEngine` generate or encode process is kicked off when a new
RPCProcessRequest is received by the input_socket. RPCProcessRequest is received by the input_socket.
The self.engine_loop checks the input_socket for new requests, The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal adds them to the LLMEngine if there are any, calls the internal
:class:`LLMEngine.step()`, and sends the RequestOutputs back over :class:`LLMEngine.step()`, and sends the RequestOutputs back over
@ -428,6 +430,9 @@ def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, disable_log_stats: bool, ipc_path: str, disable_log_stats: bool,
disable_log_requests: bool, engine_alive): disable_log_requests: bool, engine_alive):
try: try:
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
engine = MQLLMEngine.from_vllm_config( engine = MQLLMEngine.from_vllm_config(
vllm_config=vllm_config, vllm_config=vllm_config,
usage_context=usage_context, usage_context=usage_context,

View File

@ -82,6 +82,8 @@ from vllm.entrypoints.openai.serving_transcription import (
from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit) is_valid_ipv6_address, set_ulimit)
@ -221,6 +223,9 @@ async def build_async_engine_client_from_engine_args(
# so we need to spawn a new process # so we need to spawn a new process
context = multiprocessing.get_context("spawn") context = multiprocessing.get_context("spawn")
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
# The Process can raise an exception during startup, which may # The Process can raise an exception during startup, which may
# not actually result in an exitcode being reported. As a result # not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information. # we use a shared variable to communicate the information.
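Both call sites in this change apply the same ordering rule: dynamically loaded transformer config classes must be registered for serialization by value before the engine process is spawned, otherwise the child cannot unpickle the config it is handed. A minimal sketch of that ordering (the helper and its arguments are hypothetical; only maybe_register_config_serialize_by_value is the real vLLM call):

import multiprocessing

from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)


def start_engine_process(run_engine, *engine_args):
    # Register custom config classes for pickle-by-value *before* spawning,
    # so the spawned child can deserialize the VllmConfig it receives.
    maybe_register_config_serialize_by_value()
    context = multiprocessing.get_context("spawn")
    proc = context.Process(target=run_engine, args=engine_args)
    proc.start()
    return proc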

View File

@ -1,15 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand from vllm.lora.ops.triton_ops.lora_expand import lora_expand
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.triton_ops.lora_shrink import lora_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401
__all__ = [ __all__ = [
"bgmv_expand", "lora_expand",
"bgmv_expand_slice", "lora_shrink",
"bgmv_shrink", "LoRAKernelMeta",
"sgmv_expand",
"sgmv_shrink",
] ]
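The practical effect of this __init__ rewrite is a smaller public surface; the retired per-shape ops map onto the two unified kernels roughly as follows (mapping inferred from the punica_gpu.py changes later in this diff):

# sgmv_shrink / bgmv_shrink           -> lora_shrink
# sgmv_expand / bgmv_expand(_slice)   -> lora_expand
# prefill/decode-specific metadata    -> LoRAKernelMeta.meta_args(...)
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink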

View File

@ -1,188 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV; introducing SPLIT_N can additionally improve performance for a
large hidden_size.
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
current_n_c = tl.max_contiguous(current_n, BLOCK_N)
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n_c[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride,
mask=c_mask,
other=0.0)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def _bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora_b's weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no LoRA should be
applied.
add_inputs (bool, optional): Defaults to True. When True, the LoRA
results are added to the existing values in output_tensor.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
def bgmv_expand_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
return
try:
direct_register_custom_op(
op_name="bgmv_expand",
op_func=_bgmv_expand,
mutates_args=["output_tensor"],
fake_impl=bgmv_expand_fake,
)
bgmv_expand = torch.ops.vllm.bgmv_expand
except AttributeError:
bgmv_expand = _bgmv_expand
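For readers tracking what was retired, a dense PyTorch reference of the semantics the deleted bgmv_expand op implemented (a per-token GEMV against the selected LoRA-B matrix; a sketch for understanding only, not the Triton kernel):

import torch


def bgmv_expand_ref(inputs: torch.Tensor,          # [num_tokens, lora_rank]
                    lora_b_weights: torch.Tensor,  # [num_loras, hidden_size, lora_rank]
                    output_tensor: torch.Tensor,   # [num_tokens, hidden_size]
                    lora_indices: torch.Tensor,    # [num_tokens], -1 = no LoRA
                    add_inputs: bool = True) -> None:
    for i in range(inputs.size(0)):
        lora_idx = int(lora_indices[i])
        if lora_idx == -1:
            continue  # this token has no LoRA attached
        y = (lora_b_weights[lora_idx].to(inputs.dtype) @ inputs[i]).to(output_tensor.dtype)
        output_tensor[i] = output_tensor[i] + y if add_inputs else y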

View File

@ -1,207 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_slice_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV; introducing SPLIT_N can additionally improve performance for a
large hidden_size.
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
slice_offset * cn_stride)
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store
# operation uses the same mask, eliminating the risk of garbage
# values propagating
tiled_out = tl.load(c_ptr + current_n * cn_stride,
mask=c_mask,
other=None)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def _bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora_b's weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no LoRA should be
applied.
slice_offset (int): output_tensor's column offset for this slice
slice_size (int): this slice's output size
add_inputs (bool, optional): Defaults to True. When True, the LoRA
results are added to the existing values in output_tensor.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_slice_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
def bgmv_expand_slice_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> None:
return
try:
direct_register_custom_op(
op_name="bgmv_expand_slice",
op_func=_bgmv_expand_slice,
mutates_args=["output_tensor"],
fake_impl=bgmv_expand_slice_fake,
)
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice
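The deleted _slice variant differed only in where the result landed; building on the bgmv_expand_ref sketch a few files above:

def bgmv_expand_slice_ref(inputs, lora_b_weights, output_tensor, lora_indices,
                          slice_offset, slice_size, add_inputs=True):
    # Same per-token GEMV, but written into a column slice of the fused
    # output tensor (e.g. one of the q/k/v partitions).
    out_view = output_tensor[:, slice_offset:slice_offset + slice_size]
    bgmv_expand_ref(inputs, lora_b_weights, out_view, lora_indices, add_inputs)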

View File

@ -1,168 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_shrink_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
scaling,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
GroupGEMV; introducing SPLIT-K can additionally improve performance for a
large hidden_size.
"""
pid_sk = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_n = tl.arange(0, BLOCK_N)
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
a_ptr = input_ptr + cur_batch * xm_stride
b_ptr = lora_ptr + l0_stride * lora_index
accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
for k in range(0, K, BLOCK_K * SPLIT_K):
current_k = k + offset_k
current_k_c = tl.max_contiguous(current_k, BLOCK_K)
tiled_a = tl.load(
a_ptr + current_k_c,
mask=current_k < K,
other=0.0,
) # [BLOCK_K]
b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
tiled_b = tl.load(
b_ptr + offset_n[:, None] * lora_k_stride +
current_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
accumulator += tl.sum(tiled_a * tiled_b, 1)
accumulator *= scaling
offset_cn = tl.arange(0, BLOCK_N)
c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
c_mask = offset_cn < N
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def _bgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora_a's weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no LoRA should be
applied.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_a_weights.size(-1)
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
batches = lora_indices_tensor.size(0)
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
BLOCK_N = triton.next_power_of_2(N)
# First try to load optimal config from the file
config = get_lora_op_configs("bgmv_shrink", batches, K)
grid = lambda META: (
META["SPLIT_K"],
batches,
)
_bgmv_shrink_kernel[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_N=BLOCK_N,
**config,
)
return
def bgmv_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
return
try:
direct_register_custom_op(
op_name="bgmv_shrink",
op_func=_bgmv_shrink,
mutates_args=["output_tensor"],
fake_impl=bgmv_shrink_fake,
)
bgmv_shrink = torch.ops.vllm.bgmv_shrink
except AttributeError:
bgmv_shrink = _bgmv_shrink
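And the matching dense reference for the deleted bgmv_shrink op (project each token onto the low-rank space of its selected LoRA-A matrix, with optional scaling); again a sketch for understanding only:

import torch


def bgmv_shrink_ref(inputs: torch.Tensor,          # [num_tokens, hidden_size]
                    lora_a_weights: torch.Tensor,  # [num_loras, lora_rank, hidden_size]
                    output_tensor: torch.Tensor,   # [num_tokens, lora_rank]
                    lora_indices: torch.Tensor,    # [num_tokens], -1 = no LoRA
                    scaling: float = 1.0) -> None:
    for i in range(inputs.size(0)):
        lora_idx = int(lora_indices[i])
        if lora_idx == -1:
            continue
        output_tensor[i] = scaling * (lora_a_weights[lora_idx] @ inputs[i])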

View File

@ -18,7 +18,7 @@ from vllm.utils import direct_register_custom_op
@triton.jit @triton.jit
def _v1_expand_kernel( def _lora_expand_kernel(
input_ptr, input_ptr,
lora_ptr, lora_ptr,
out_ptr, out_ptr,
@ -125,7 +125,7 @@ def _v1_expand_kernel(
@torch.inference_mode() @torch.inference_mode()
def _v1_expand( def _lora_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: List[ lora_b_weights: List[
torch.Tensor], # shape [num_lora, hidden_size, lora_rank] torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
@ -216,7 +216,7 @@ def _v1_expand(
MAX_LORAS, MAX_LORAS,
) )
_v1_expand_kernel[grid]( _lora_expand_kernel[grid](
inputs, inputs,
lora_ptr_tensor, lora_ptr_tensor,
output_tensor, output_tensor,
@ -254,7 +254,7 @@ def _v1_expand(
return return
def _v1_expand_fake( def _lora_expand_fake(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor], lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
@ -271,12 +271,12 @@ def _v1_expand_fake(
try: try:
direct_register_custom_op( direct_register_custom_op(
op_name="v1_expand", op_name="lora_expand",
op_func=_v1_expand, op_func=_lora_expand,
mutates_args=["output_tensor"], mutates_args=["output_tensor"],
fake_impl=_v1_expand_fake, fake_impl=_lora_expand_fake,
) )
v1_expand = torch.ops.vllm.v1_expand lora_expand = torch.ops.vllm.lora_expand
except AttributeError: except AttributeError:
v1_expand = _v1_expand lora_expand = _lora_expand

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
V1 LoRA kernels metadata preparation utilities. LoRA kernels metadata preparation utilities.
""" """
from dataclasses import dataclass from dataclasses import dataclass
@ -10,7 +10,7 @@ import torch
@dataclass @dataclass
class V1KernelMeta: class LoRAKernelMeta:
token_lora_mapping: torch.Tensor token_lora_mapping: torch.Tensor
token_indices_sorted_by_lora_ids: torch.Tensor token_indices_sorted_by_lora_ids: torch.Tensor
active_lora_ids: torch.Tensor active_lora_ids: torch.Tensor
@ -19,7 +19,7 @@ class V1KernelMeta:
@staticmethod @staticmethod
def make(max_loras: int, max_num_tokens: int, def make(max_loras: int, max_num_tokens: int,
device: Union[torch.device, str]) -> "V1KernelMeta": device: Union[torch.device, str]) -> "LoRAKernelMeta":
token_lora_mapping = torch.empty(max_num_tokens, token_lora_mapping = torch.empty(max_num_tokens,
dtype=torch.int32, dtype=torch.int32,
@ -47,7 +47,7 @@ class V1KernelMeta:
lora_token_start_loc = torch.zeros(max_loras + 2, lora_token_start_loc = torch.zeros(max_loras + 2,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
return V1KernelMeta( return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping, token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids, active_lora_ids=active_lora_ids,
@ -105,7 +105,7 @@ class V1KernelMeta:
This function returns the kernel metadata required for the current This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the forward pass execution of the kernel. The function returns all the
metadata required by the kernel, in order, as a tuple, so it can be metadata required by the kernel, in order, as a tuple, so it can be
unpacked directly during the v1_shrink/v1_expand function call. unpacked directly during the lora_shrink/lora_expand function call.
Args: Args:
token_nums (int): Number of input tokens in the current forward token_nums (int): Number of input tokens in the current forward
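A usage sketch tying this metadata class to the unified kernels, following the call pattern PunicaWrapperGPU uses later in this diff (shapes, dtypes, and the all-zeros token mapping are illustrative assumptions):

import torch

from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_shrink

num_loras, rank, hidden, num_tokens = 8, 16, 4096, 256

# Built once, sized for the worst case.
meta = LoRAKernelMeta.make(max_loras=num_loras, max_num_tokens=8192,
                           device="cuda")

# Refreshed every step with the current token -> LoRA id mapping (-1 = none),
# then meta_args() is unpacked straight into the kernel call.
token_lora_mapping = torch.zeros(num_tokens, dtype=torch.int32, device="cuda")
meta.prepare_tensors(token_lora_mapping)

x = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
lora_a = torch.randn(num_loras, rank, hidden, dtype=torch.float16, device="cuda")
shrink_buf = torch.zeros(1, num_tokens, rank, dtype=torch.float32, device="cuda")
lora_shrink(x, [lora_a], shrink_buf, *meta.meta_args(num_tokens), 1.0)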

View File

@ -18,15 +18,15 @@ from vllm.utils import direct_register_custom_op
@triton.jit @triton.jit
def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
token_indices_sorted_by_lora_ids, num_tokens_per_lora, token_indices_sorted_by_lora_ids, num_tokens_per_lora,
lora_token_start_loc, lora_ids, scaling, input_d0_stride, lora_token_start_loc, lora_ids, scaling,
input_d1_stride, lora_d0_stride, lora_d1_stride, input_d0_stride, input_d1_stride, lora_d0_stride,
lora_d2_stride, output_d0_stride, output_d1_stride, lora_d1_stride, lora_d2_stride, output_d0_stride,
output_d2_stride, BLOCK_M: tl.constexpr, output_d1_stride, output_d2_stride,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
SLICE_NUM: tl.constexpr): SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
cta_n_num = tl.cdiv(N, BLOCK_N) cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M) cta_m_num = tl.cdiv(M, BLOCK_M)
@ -96,7 +96,7 @@ def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
@torch.inference_mode() @torch.inference_mode()
def _v1_shrink( def _lora_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size] inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: List[ lora_a_weights: List[
torch.Tensor], # shape [num_loras, lora_rank, hidden_size] torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
@ -174,7 +174,7 @@ def _v1_shrink(
MAX_LORAS, MAX_LORAS,
) )
_v1_shrink_kernel[grid]( _lora_shrink_kernel[grid](
inputs, inputs,
lora_ptr_tensor, lora_ptr_tensor,
output_tensor, output_tensor,
@ -209,7 +209,7 @@ def _v1_shrink(
return return
def _v1_shrink_fake( def _lora_shrink_fake(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor], lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
@ -225,12 +225,12 @@ def _v1_shrink_fake(
try: try:
direct_register_custom_op( direct_register_custom_op(
op_name="v1_shrink", op_name="lora_shrink",
op_func=_v1_shrink, op_func=_lora_shrink,
mutates_args=["output_tensor"], mutates_args=["output_tensor"],
fake_impl=_v1_shrink_fake, fake_impl=_lora_shrink_fake,
) )
v1_shrink = torch.ops.vllm.v1_shrink lora_shrink = torch.ops.vllm.lora_shrink
except AttributeError: except AttributeError:
v1_shrink = _v1_shrink lora_shrink = _lora_shrink

View File

@ -1,249 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .kernel_utils import do_expand_kernel
from .utils import _get_lora_b_ptr
@triton.jit
def _sgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
slice_start_loc,
input_d0_stride,
input_d1_stride,
input_d2_stride, # 1
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
slice_id = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
# When the output dimensions of each slice are the same, curr_N = N;
# otherwise curr_N = tl.load(output_hs_ptr + slice_id). The latter case
# arises in GQA's qkv linear.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M >= M:
return
if pid_n * BLOCK_N >= curr_N:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
m_offset = tl.load(b_seq_start_loc + cur_batch)
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
cta_m_offset = m_offset + (pid_m * BLOCK_M)
offset_m = tl.arange(0, BLOCK_M)
ram = cta_m_offset + tl.max_contiguous(
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
do_expand_kernel(
pid_n,
lora_index,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
curr_N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M,
BLOCK_N,
BLOCK_K,
SAME_STRIDE,
SLICE_NUM,
EVEN_K,
CAST_TYPE,
ADD_INPUTS,
)
@torch.inference_mode()
def _sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (List[torch.Tensor]): lora_b's weights
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into the input. E.g., if the sequence lengths are [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The number of tokens in the batch. Used to verify that
the number of tokens in the inputs matches the metadata.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(1) == token_nums
assert inputs.size(0) == len(lora_b_weights)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert output_tensor.is_contiguous()
(slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
b_seq_start_loc.device)
# TODO tuning this config
K = lora_b_weights[0].shape[-1] # K= rank
BLOCK_M = 64
BLOCK_N = 128
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
batches,
len(lora_b_weights),
)
_sgmv_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
MAX_N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
len(lora_b_weights),
same_stride,
)
return
def _sgmv_expand_fake(
inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="sgmv_expand",
op_func=_sgmv_expand,
mutates_args=["output_tensor"],
fake_impl=_sgmv_expand_fake,
)
sgmv_expand = torch.ops.vllm.sgmv_expand
except AttributeError:
sgmv_expand = _sgmv_expand
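To make the removed op's metadata arguments concrete, here is how they relate to each other, reproducing the b_seq_start_loc example from the docstring above (an exclusive prefix sum of the per-sequence lengths):

import torch

seq_len_tensor = torch.tensor([4, 6])
b_seq_start_loc = torch.zeros_like(seq_len_tensor)
b_seq_start_loc[1:] = torch.cumsum(seq_len_tensor, dim=0)[:-1]
batches = seq_len_tensor.numel()             # 2
max_seq_length = int(seq_len_tensor.max())   # 6
token_nums = int(seq_len_tensor.sum())       # 10
# b_seq_start_loc == tensor([0, 4])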

View File

@ -1,224 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List
import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .kernel_utils import do_shrink_kernel
from .utils import _get_lora_a_ptr
@triton.jit
def _sgmv_shrink_kernel(
input_ptr,
lora_ptr, #1-3
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
scaling,
input_d0_stride,
input_d1_stride, # 1
lora_d0_stride,
lora_d1_stride,
lora_d2_stride, # 1
output_d0_stride,
output_d1_stride,
output_d2_stride, # 1
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr):
"""
The sgmv shrink triton kernel is based on GroupGEMM + SPLIT-K.
The multi-LoRA GEMM can be treated as a grouped GEMM; introducing
SPLIT-K additionally improves performance.
"""
pid = tl.program_id(axis=0)
pid_mix = tl.program_id(axis=1)
cur_batch = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
if SLICE_NUM == 1:
slice_id: tl.constexpr = 0
pid_sk = tl.program_id(axis=1)
else:
pid_mix = tl.program_id(axis=1)
slice_id = pid_mix // SPLIT_K
pid_sk = pid_mix % SPLIT_K
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M >= M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
m_offset = tl.load(b_seq_start_loc + cur_batch)
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
cta_m_offset = m_offset + (pid_m * BLOCK_M)
offset_m = tl.arange(0, BLOCK_M)
ram = cta_m_offset + tl.max_contiguous(
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_index,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
cta_m_len,
ram,
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
SLICE_NUM)
@torch.inference_mode()
def _sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (List[torch.Tensor]): lora_a's weights
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into the input. E.g., if the sequence lengths are [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The number of tokens in the batch. Used to verify that
the number of tokens in the inputs matches the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device)
# TODO tuning this config
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
BLOCK_M = 32
BLOCK_N = 16
BLOCK_K = 32
SPLIT_K = 8
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K * len(lora_a_weights),
batches,
)
_sgmv_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
len(lora_a_weights),
)
return
def sgmv_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor],
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
return
try:
direct_register_custom_op(
op_name="sgmv_shrink",
op_func=_sgmv_shrink,
mutates_args=["output_tensor"],
fake_impl=sgmv_shrink_fake,
)
sgmv_shrink = torch.ops.vllm.sgmv_shrink
except AttributeError:
sgmv_shrink = _sgmv_shrink
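A per-sequence dense reference of the grouped GEMM the deleted shrink kernel performed, one lora_a_weights entry per output slice (for understanding only; the Triton kernel additionally tiles the reduction with SPLIT-K):

import torch
from typing import List


def sgmv_shrink_ref(inputs: torch.Tensor,                # [num_tokens, hidden_size]
                    lora_a_weights: List[torch.Tensor],  # each [num_loras, rank, hidden_size]
                    output_tensor: torch.Tensor,         # [num_slices, num_tokens, rank]
                    b_seq_start_loc: torch.Tensor,
                    seq_len_tensor: torch.Tensor,
                    lora_indices_tensor: torch.Tensor,
                    scaling: float) -> None:
    for b in range(seq_len_tensor.size(0)):
        lora_idx = int(lora_indices_tensor[b])
        if lora_idx == -1:
            continue
        start = int(b_seq_start_loc[b])
        end = start + int(seq_len_tensor[b])
        x = inputs[start:end]                            # [seq_len, hidden_size]
        for s, a in enumerate(lora_a_weights):
            output_tensor[s, start:end] = scaling * (x @ a[lora_idx].t())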

View File

@ -1,55 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import functools
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
@functools.lru_cache
def _get_op_configs(op_type: str, batch: int, hidden_size: int):
# TODO: add optimal configurations
return None
def _check_divisibility(hidden_size: int):
# The bgmv_expand kernel requires that the hidden_size be divisible by
# the number below.
divisibility = [2, 4, 8, 16, 32, 64]
divisibility.sort(reverse=True)
for div in divisibility:
if hidden_size % div == 0:
return div
# hidden_size is an odd number
return 1
def _get_default_config(op_type: str, batch: int, hidden_size: int):
if op_type == "expand":
return {
"BLOCK_N": 256,
"SPLIT_N": _check_divisibility(hidden_size),
"num_warps": 8
}
else:
return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8}
def get_lora_op_configs(op_type: str, batch: int,
hidden_size: int) -> Dict[str, int]:
"""Inspired by `fused_moe_kernel`
The return value will be a dictionary mapping an irregular grid of batch
sizes and hidden_size to configurations of the bgmv-related kernel.
NOTE: It currently only supports the default configuration. We plan to
generate optimal configurations for different hardware in the future using
scripts similar to `benchmark_moe.py`.
"""
config = _get_op_configs(op_type, batch, hidden_size)
if not config:
config = _get_default_config(op_type, batch, hidden_size)
return config
_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} _LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} _LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
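For reference, the removed _check_divisibility helper chose SPLIT_N as the largest listed power-of-two divisor (up to 64) of hidden_size; a few worked values under that rule:

def largest_listed_divisor(hidden_size: int) -> int:
    # Re-statement of the removed logic, kept here only to show the
    # resulting SPLIT_N values; no longer part of vLLM.
    for div in (64, 32, 16, 8, 4, 2):
        if hidden_size % div == 0:
            return div
    return 1

assert largest_listed_divisor(4096) == 64
assert largest_listed_divisor(96) == 32   # 96 % 64 != 0, 96 % 32 == 0
assert largest_listed_divisor(81) == 1    # odd hidden size: no split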

View File

@ -1,11 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand
from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta
from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink
__all__ = [
"v1_expand",
"v1_shrink",
"V1KernelMeta",
]

View File

@ -10,20 +10,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
import torch import torch
import vllm.envs as env
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
if HAS_TRITON: if HAS_TRITON:
if env.VLLM_USE_V1: from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand, lora_shrink)
v1_shrink)
else:
from vllm.lora.ops.triton_ops import bgmv_expand
from vllm.lora.ops.triton_ops import bgmv_expand_slice
from vllm.lora.ops.triton_ops import bgmv_shrink
from vllm.lora.ops.triton_ops import sgmv_expand
from vllm.lora.ops.triton_ops import sgmv_shrink
from .punica_base import PunicaWrapperBase from .punica_base import PunicaWrapperBase
@ -32,57 +24,8 @@ if TYPE_CHECKING:
from vllm.lora.models import LongContextLoRAContext from vllm.lora.models import LongContextLoRAContext
class V1KernelMixin:
def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int,
max_batches: int, device: Union[torch.device, str]):
self.token_mapping_v1_meta = V1KernelMeta.make(max_loras,
max_num_batched_tokens,
device=device)
self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras,
max_batches,
device=device)
def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor,
sampler_indices: torch.Tensor):
self.token_mapping_v1_meta.prepare_tensors(token_lora_indices)
self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices)
def _v1_apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
scale: float,
):
v1_shrink(
x,
w_t_all,
y,
*self.token_mapping_v1_meta.meta_args(x.size(0)),
scale,
)
def _v1_apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
offset_start: int,
add_inputs: bool,
):
v1_expand(
x,
w_t_all,
y,
*self.token_mapping_v1_meta.meta_args(x.size(0)),
offset_start=offset_start,
add_inputs=add_inputs,
)
@final @final
class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin): class PunicaWrapperGPU(PunicaWrapperBase):
""" """
PunicaWrapperGPU is designed to manage and provide metadata for the punica PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for kernel. The main function is to maintain the state information for
@ -96,9 +39,12 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
self.max_loras = kwargs['max_loras'] self.max_loras = kwargs['max_loras']
if env.VLLM_USE_V1: self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
self._v1_make_metadata(self.max_loras, max_num_batched_tokens, max_num_batched_tokens,
max_batches, device) device=device)
self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
max_batches,
device=device)
def update_metadata( def update_metadata(
self, self,
@ -110,83 +56,18 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
long_lora_context: Optional["LongContextLoRAContext"] = None, long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs): **kwargs):
if env.VLLM_USE_V1: self.is_prefill = mapping.is_prefill
self.is_prefill = mapping.is_prefill self._update_base_metadata(mapping, lora_index_to_id, max_loras,
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size,
vocab_size, extra_vocab_size, long_lora_context)
long_lora_context)
self._v1_prepare_metadata_tensors(self.token_lora_indices,
self.sampler_indices)
else:
# Forward to base class update_metadata
PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id,
max_loras, vocab_size,
extra_vocab_size,
long_lora_context, **kwargs)
def _apply_shrink_prefill( # Prepare cuda kernel metadata tensors
self, self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
y: torch.Tensor, self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
scale: float,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _apply_shrink_decode( def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
self, lora_a_stacked: Tuple[torch.Tensor,
y: torch.Tensor, ...], scale: float, **kwargs):
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _apply_expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
offset_start: int,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
offset_start=offset_start,
add_inputs=add_inputs,
)
def _apply_expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_inputs: bool,
):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_inputs)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs):
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
When `is_prefill` is true, it indicates that it is currently the When `is_prefill` is true, it indicates that it is currently the
@ -199,33 +80,24 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
y[i] += (x @ lora_a_stacked[i]) * scale y[i] += (x @ lora_a_stacked[i]) * scale
Args: Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors y (torch.Tensor): Output tensors
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation scale (float): Scaling factor for the operation
""" """
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
lora_shrink(
if env.VLLM_USE_V1: x,
self._v1_apply_shrink(y, x, lora_a_stacked, scale) # type: ignore lora_a_stacked,
else: y,
if self.is_prefill: *self.token_mapping_meta.meta_args(x.size(0)),
# NOTE fused kernel scale,
self._apply_shrink_prefill( )
y, # type: ignore
x,
lora_a_stacked,
scale)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink_decode(y[slice_idx], x,
lora_a_stacked[slice_idx], scale)
def add_expand(self, def add_expand(self,
y: torch.Tensor, y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor,
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
@ -244,7 +116,7 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors x (torch.Tensor): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight bias's weight
@ -259,37 +131,19 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
self._apply_bias(token_lora_indices, y, output_slices, self._apply_bias(token_lora_indices, y, output_slices,
lora_bias_stacked) lora_bias_stacked)
if env.VLLM_USE_V1: assert x.ndim == 3
# TODO (varun): Profile with add_inputs = False. i.e. move the assert x.size(0) == len(output_slices)
# addition out of the kernel num_tokens = x.size(1) # first dimension is the num slices
self._v1_apply_expand(
y, lora_expand(
x, # type: ignore x,
lora_b_stacked, lora_b_stacked,
offset_start, y,
add_inputs=True) *self.token_mapping_meta.meta_args(num_tokens),
else: offset_start=offset_start,
add_inputs=True,
)
if self.is_prefill:
# NOTE fused kernel
self._apply_expand_prefill(
y,
x, # type: ignore
lora_b_stacked,
offset_start,
add_inputs=True)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand_decode(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_start,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_start += output_slices[slice_idx]
y = y.view_as(y_org) y = y.view_as(y_org)
def add_lora_embedding(self, def add_lora_embedding(self,
@ -311,24 +165,14 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
add_inputs (bool): Default to True. add_inputs (bool): Default to True.
""" """
if env.VLLM_USE_V1: lora_expand(
self._v1_apply_expand(y, x.unsqueeze(dim=0),
x.unsqueeze(dim=0), (lora_b_stacked, ), (lora_b_stacked, ),
offset_start=0, y,
add_inputs=add_inputs) *self.token_mapping_meta.meta_args(x.size(0)),
else: offset_start=0,
if self.is_prefill: add_inputs=add_inputs,
sgmv_expand( )
x.unsqueeze(dim=0),
(lora_b_stacked, ),
y,
*self.prefill_metadata,
offset_start=0,
add_inputs=add_inputs,
)
else:
bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
add_inputs)
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
@ -339,7 +183,7 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: Tuple[int, ...],
*, *,
buffer: Optional[Tuple[torch.Tensor, ...]] = None, buffer: Optional[torch.Tensor] = None,
**kwargs) -> None: **kwargs) -> None:
""" """
Applicable to linear-related lora. Applicable to linear-related lora.
@ -361,7 +205,7 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size. output_slices (Tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[torch.Tensor]): Defaults to None.
""" """
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
@ -431,21 +275,11 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
dtype=torch.float32, dtype=torch.float32,
device=x.device) device=x.device)
if env.VLLM_USE_V1: lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), *self.prompt_mapping_meta.meta_args(x.size(0)), scale)
*self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale)
v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
y, y,
*self.prompt_mapping_v1_meta.meta_args(buffer.size(0)), *self.prompt_mapping_meta.meta_args(buffer.size(0)),
add_inputs=True) add_inputs=True)
else:
# V0 LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer,
lora_b_stacked,
y,
self.sampler_indices,
add_inputs=True)
y = y.view_as(y_org) y = y.view_as(y_org)
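Putting the two unified kernels together, the add_lora_linear path above computes, per token and per output slice s, y[:, slice_s] += (B_s[idx] @ (A_s[idx] @ x)) * scale. A dense reference under simplified 3D stacked-weight shapes (the real lora_a_stacked/lora_b_stacked tensors carry an extra unit dim, and the kernels batch this work):

import torch
from typing import Tuple


def add_lora_linear_ref(y: torch.Tensor,                           # [num_tokens, sum(output_slices)]
                        x: torch.Tensor,                           # [num_tokens, hidden_size]
                        lora_a_stacked: Tuple[torch.Tensor, ...],  # each [num_loras, rank, hidden_size]
                        lora_b_stacked: Tuple[torch.Tensor, ...],  # each [num_loras, out_size_s, rank]
                        token_lora_indices: torch.Tensor,          # [num_tokens], -1 = no LoRA
                        scale: float,
                        output_slices: Tuple[int, ...]) -> None:
    offset = 0
    for s, (a, b) in enumerate(zip(lora_a_stacked, lora_b_stacked)):
        for i in range(x.size(0)):
            idx = int(token_lora_indices[i])
            if idx == -1:
                continue
            shrunk = scale * (a[idx] @ x[i])                       # lora_shrink
            y[i, offset:offset + output_slices[s]] += b[idx] @ shrunk  # lora_expand
        offset += output_slices[s]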

View File

@ -245,7 +245,6 @@ class MambaMixer2(CustomOp):
assert num_heads % self.tp_size == 0, \ assert num_heads % self.tp_size == 0, \
"Tensor parallel world size must divide num heads." "Tensor parallel world size must divide num heads."
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
( (
"If tensor parallel world size does not divide num_heads, " "If tensor parallel world size does not divide num_heads, "

View File

@ -10,7 +10,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@ -140,6 +139,10 @@ def _fused_moe_gguf(
qweight_type2: int, qweight_type2: int,
act, act,
) -> torch.Tensor: ) -> torch.Tensor:
# lazy import to avoid triggering triton import in CPU backend
from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size)
out_hidden_states = torch.empty_like(x) out_hidden_states = torch.empty_like(x)
if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES: if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
num_tokens, _ = x.shape num_tokens, _ = x.shape

View File

@ -38,8 +38,6 @@ from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class BambaMLP(nn.Module): class BambaMLP(nn.Module):

View File

@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel from .blip import BlipVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
@ -565,12 +565,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported. pixel_values = flatten_bn(pixel_values, concat=True)
pixel_values = pixel_values.squeeze(1)
return Blip2ImagePixelInputs( return Blip2ImagePixelInputs(
type="pixel_values", type="pixel_values",
@ -578,12 +577,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}") f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported. image_embeds = flatten_bn(image_embeds, concat=True)
image_embeds = image_embeds.squeeze(1)
return Blip2ImageEmbeddingInputs( return Blip2ImageEmbeddingInputs(
type="image_embeds", type="image_embeds",

View File

@ -39,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
@ -972,12 +972,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None: if pixel_values is None:
return None return None
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported. pixel_values = flatten_bn(pixel_values, concat=True)
pixel_values = pixel_values.squeeze(1)
return ChameleonImagePixelInputs( return ChameleonImagePixelInputs(
type="pixel_values", type="pixel_values",

View File

@ -478,7 +478,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
flatten_bn(images_spatial_crop, concat=True))) flatten_bn(images_spatial_crop, concat=True)))
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}") f"Got type: {type(image_embeds)}")

View File

@ -578,7 +578,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")

View File

@ -838,7 +838,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return None return None
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}") f"Got type: {type(image_embeds)}")
@ -856,7 +856,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat)}") f"Got type: {type(pixel_values_flat)}")
assert isinstance(image_num_patches, (torch.Tensor, list)) if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(pixel_values_flat)}")
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",

View File

@ -36,8 +36,6 @@ from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class JambaMoE(nn.Module): class JambaMoE(nn.Module):

View File

@ -349,21 +349,18 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
List[b, Tensor(nb_frames, nb_channels, height, width)] List[b, Tensor(nb_frames, nb_channels, height, width)]
} }
""" """
pixel_values = kwargs.pop("pixel_values_videos", None) pixel_values_videos = kwargs.pop("pixel_values_videos", None)
if pixel_values is None: if pixel_values_videos is None:
return None return None
if not (is_list_of(pixel_values, if not isinstance(pixel_values_videos, (torch.Tensor, list)):
(torch.Tensor)) # different shape videos raise ValueError("Incorrect type of pixel_values_videos. "
or isinstance(pixel_values, f"Got type: {type(pixel_values_videos)}")
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaNextVideoPixelInputs( return LlavaNextVideoPixelInputs(
type="pixel_values_videos", type="pixel_values_videos",
data=pixel_values, data=pixel_values_videos,
) )
def _select_image_features(self, image_features: torch.Tensor, *, def _select_image_features(self, image_features: torch.Tensor, *,

View File

@ -574,10 +574,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values_videos is None: if pixel_values_videos is None:
return None return None
if not (is_list_of(pixel_values_videos, if not isinstance(pixel_values_videos, (torch.Tensor, list)):
torch.Tensor) # different shape videos
or isinstance(pixel_values_videos,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel_values_videos. " raise ValueError("Incorrect type of pixel_values_videos. "
f"Got type: {type(pixel_values_videos)}") f"Got type: {type(pixel_values_videos)}")

View File

@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -283,17 +283,19 @@ class Olmo2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
         """
         if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
             # Get embeddings of input.
             # shape: (batch_size, seq_len, d_model)
-            inputs_embeds = self.embed_tokens(input_ids)
-
-            # embed positions
-            hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -337,7 +339,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
             prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
+        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -346,11 +348,13 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
         )
         return hidden_states
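The `inputs_embeds` plumbing lets a caller skip the embedding lookup and feed precomputed embeddings (for example, merged text and multimodal features) straight through the causal-LM wrapper into the model. A minimal, self-contained sketch of that control flow; the class below is an illustrative stand-in, not the Olmo2 code:

```python
from typing import Optional

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Toy stand-in showing the inputs_embeds short-circuit."""

    def __init__(self, vocab_size: int = 128, d_model: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # If the caller already computed embeddings, use them directly;
        # otherwise look them up from the token ids.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)
        return hidden_states

model = TinyModel()
ids = torch.tensor([[1, 2, 3]])
out_from_ids = model(ids)                                  # embedding lookup path
out_from_embeds = model(ids, inputs_embeds=out_from_ids)   # bypass path
assert torch.allclose(out_from_ids, out_from_embeds)
```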


@@ -23,7 +23,7 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
@@ -270,12 +270,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            # Remove the N dimension until multiple images are supported.
-            pixel_values = pixel_values.squeeze(1)
+            pixel_values = flatten_bn(pixel_values, concat=True)

             return PaliGemmaImagePixelInputs(
                 type="pixel_values",
@@ -287,8 +286,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
+            image_embeds = flatten_bn(image_embeds, concat=True)

             return PaliGemmaImageEmbeddingInputs(
                 type="image_embeds",


@@ -711,7 +711,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
@@ -722,13 +722,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
             )

         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
+            if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

             return QwenImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=flatten_bn(image_embeds, concat=True),
             )

         return None


@@ -105,6 +105,7 @@ _TEXT_GENERATION_MODELS = {
     "SolarForCausalLM": ("solar", "SolarForCausalLM"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
+    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
     # [Encoder-decoder]
     "BartModel": ("bart", "BartForConditionalGeneration"),
     "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),

File diff suppressed because it is too large


@@ -62,9 +62,10 @@ class LoRAModelRunnerMixin:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")

-        # Set is_prefill to True, so we always use the SGMV kernels.
-        # For cuda platforms, we have specialized triton kernels, and
-        # the cuda path ignores `is_prefill`.
+        # Set is_prefill to True, so we always use the SGMV kernels on
+        # non-cuda platforms.
+        # On cuda platforms we use the same kernels for prefill and
+        # decode and this flag is generally ignored.
         lora_mapping = LoRAMapping(token_lora_mapping,
                                    prompt_lora_mapping,
                                    is_prefill=True)
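For context on what this mapping carries: `token_lora_mapping` assigns one LoRA id per token in the flattened batch, `prompt_lora_mapping` assigns one LoRA id per request, and, as the updated comment says, the `is_prefill` flag is largely ignored on CUDA now that the same unified kernels serve prefill and decode. A small hedged sketch of building such a mapping for a two-request batch; the id values are made up, only the constructor shape is taken from the diff, and the import path is an assumption that may differ by vLLM version:

```python
from vllm.lora.layers import LoRAMapping  # import path assumed; may differ by version

# Two requests: request 0 uses LoRA id 1 and contributes 3 tokens,
# request 1 uses LoRA id 2 and contributes 2 tokens.
token_lora_mapping = (1, 1, 1, 2, 2)   # one LoRA id per token in the batch
prompt_lora_mapping = (1, 2)           # one LoRA id per request

lora_mapping = LoRAMapping(token_lora_mapping,
                           prompt_lora_mapping,
                           is_prefill=True)
```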