mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Compare commits
9 Commits
d31f7844f8
...
v0.8.0
Author | SHA1 | Date | |
---|---|---|---|
966f933ee1 | |||
1a504aff6c | |||
01ca85bbd8 | |||
d82b9487ea | |||
be13281d4b | |||
54e084f7fb | |||
9e8f089d08 | |||
16e9064f84 | |||
5ac1a8e6e4 |
@ -17,13 +17,8 @@ from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import ArgPool, Bench, CudaGraphBenchParams
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
@ -167,69 +162,25 @@ class OpType(Enum):
|
||||
"""
|
||||
LoRA Ops to benchmark and its properties.
|
||||
"""
|
||||
SGMV_SHRINK = auto()
|
||||
BGMV_SHRINK = auto()
|
||||
SGMV_EXPAND = auto()
|
||||
BGMV_EXPAND = auto()
|
||||
BGMV_EXPAND_SLICE = auto()
|
||||
V1_SHRINK = auto()
|
||||
V1_EXPAND = auto()
|
||||
LORA_SHRINK = auto()
|
||||
LORA_EXPAND = auto()
|
||||
|
||||
@staticmethod
|
||||
def from_str(s: str) -> "OpType":
|
||||
if s.lower() == 'sgmv_shrink':
|
||||
return OpType.SGMV_SHRINK
|
||||
if s.lower() == 'sgmv_expand':
|
||||
return OpType.SGMV_EXPAND
|
||||
if s.lower() == 'bgmv_shrink':
|
||||
return OpType.BGMV_SHRINK
|
||||
if s.lower() == 'bgmv_expand':
|
||||
return OpType.BGMV_EXPAND
|
||||
if s.lower() == "bgmv_expand_slice":
|
||||
return OpType.BGMV_EXPAND_SLICE
|
||||
if s.lower() == "v1_shrink":
|
||||
return OpType.V1_SHRINK
|
||||
if s.lower() == "v1_expand":
|
||||
return OpType.V1_EXPAND
|
||||
if s.lower() == "lora_shrink":
|
||||
return OpType.LORA_SHRINK
|
||||
if s.lower() == "lora_expand":
|
||||
return OpType.LORA_EXPAND
|
||||
raise ValueError(f"Unrecognized str {s} to convert to OpType")
|
||||
|
||||
def is_shrink_fn(self) -> bool:
|
||||
return self in [
|
||||
OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
|
||||
]
|
||||
return self in [OpType.LORA_SHRINK]
|
||||
|
||||
def is_expand_fn(self) -> bool:
|
||||
return self in [
|
||||
OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
|
||||
]
|
||||
|
||||
def is_prefill_op(self) -> bool:
|
||||
return self in [
|
||||
OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
|
||||
OpType.V1_EXPAND
|
||||
]
|
||||
|
||||
def is_decode_op(self) -> bool:
|
||||
return self in [
|
||||
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
|
||||
OpType.V1_SHRINK, OpType.V1_EXPAND
|
||||
]
|
||||
|
||||
def is_expand_slice_fn(self) -> bool:
|
||||
return self in [OpType.BGMV_EXPAND_SLICE]
|
||||
return self in [OpType.LORA_EXPAND]
|
||||
|
||||
def num_slices(self) -> list[int]:
|
||||
if self in [
|
||||
OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
|
||||
OpType.V1_EXPAND
|
||||
]:
|
||||
# SGMV kernels and v1 kernels supports slices
|
||||
return [1, 2, 3]
|
||||
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
|
||||
return [1]
|
||||
if self in [OpType.BGMV_EXPAND_SLICE]:
|
||||
return [2, 3]
|
||||
raise ValueError(f"Unrecognized OpType {self}")
|
||||
|
||||
def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
|
||||
lora_rank: int) -> tuple[int, int, int]:
|
||||
@ -239,7 +190,7 @@ class OpType(Enum):
|
||||
k = hidden_size
|
||||
n = lora_rank
|
||||
else:
|
||||
assert self.is_expand_fn() or self.is_expand_slice_fn()
|
||||
assert self.is_expand_fn()
|
||||
m = num_tokens
|
||||
k = lora_rank
|
||||
n = hidden_size
|
||||
@ -254,7 +205,7 @@ class OpType(Enum):
|
||||
if self.is_shrink_fn():
|
||||
return op_dtype, op_dtype, torch.float32
|
||||
else:
|
||||
assert self.is_expand_fn() or self.is_expand_slice_fn()
|
||||
assert self.is_expand_fn()
|
||||
return torch.float32, op_dtype, op_dtype
|
||||
|
||||
def matmul_shapes(
|
||||
@ -268,43 +219,19 @@ class OpType(Enum):
|
||||
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
|
||||
|
||||
b_shape = (num_loras, n, k) # col-major
|
||||
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
|
||||
# SGMV shrink and V1 shrink kernels support num_slices inherently
|
||||
# in the kernel.
|
||||
if self in [OpType.LORA_SHRINK]:
|
||||
# LoRA shrink kernels support num_slices inherently in the kernel.
|
||||
return ((m, k), b_shape, (num_slices, m, n))
|
||||
if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
|
||||
# SGMV expand and V1 expand kernels support num_slices inherently
|
||||
# in the kernel
|
||||
if self in [OpType.LORA_EXPAND]:
|
||||
# LoRA expand kernels support num_slices inherently in the kernel
|
||||
return ((num_slices, m, k), b_shape, (m, n * num_slices))
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
return ((m, k), b_shape, (m, n))
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
return ((m, k), b_shape, (m, n))
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
return ((num_slices, m, k), b_shape, (m, n * num_slices))
|
||||
|
||||
raise ValueError(f"Unrecognized op_type {self}")
|
||||
|
||||
def bench_fn(self) -> Callable:
|
||||
|
||||
def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
|
||||
for x in kwargs_list:
|
||||
bgmv_expand_slice(**x)
|
||||
|
||||
if self == OpType.SGMV_SHRINK:
|
||||
return sgmv_shrink
|
||||
if self == OpType.SGMV_EXPAND:
|
||||
return sgmv_expand
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
return bgmv_shrink
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
return bgmv_expand
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
return emulate_bgmv_expand_slice
|
||||
if self == OpType.V1_SHRINK:
|
||||
return v1_shrink
|
||||
if self == OpType.V1_EXPAND:
|
||||
return v1_expand
|
||||
if self == OpType.LORA_SHRINK:
|
||||
return lora_shrink
|
||||
if self == OpType.LORA_EXPAND:
|
||||
return lora_expand
|
||||
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
@ -318,34 +245,13 @@ class OpType(Enum):
|
||||
"""
|
||||
w_dtype = lora_weights[0].dtype
|
||||
num_slices = len(lora_weights)
|
||||
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
|
||||
if self in [OpType.LORA_SHRINK]:
|
||||
for slice_idx in range(num_slices):
|
||||
ref_group_gemm(ref_out=output[slice_idx, :],
|
||||
input=input,
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
|
||||
hidden_size = lora_weights[0].shape[1]
|
||||
for slice_idx in range(num_slices):
|
||||
slice_offset = slice_idx * hidden_size
|
||||
ref_group_gemm(
|
||||
ref_out=output[:, slice_offset:slice_offset + hidden_size],
|
||||
input=input[slice_idx].clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
elif self == OpType.BGMV_SHRINK:
|
||||
assert num_slices == 1
|
||||
ref_group_gemm(ref_out=output,
|
||||
input=input,
|
||||
lora_weights=lora_weights[0],
|
||||
**kwargs)
|
||||
elif self == OpType.BGMV_EXPAND:
|
||||
assert num_slices == 1
|
||||
ref_group_gemm(ref_out=output,
|
||||
input=input.clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[0],
|
||||
**kwargs)
|
||||
elif self == OpType.BGMV_EXPAND_SLICE:
|
||||
elif self in [OpType.LORA_EXPAND]:
|
||||
hidden_size = lora_weights[0].shape[1]
|
||||
for slice_idx in range(num_slices):
|
||||
slice_offset = slice_idx * hidden_size
|
||||
@ -411,13 +317,11 @@ class BenchmarkTensors:
|
||||
input: torch.Tensor
|
||||
lora_weights_lst: list[torch.Tensor]
|
||||
output: torch.Tensor
|
||||
# metadata tensors
|
||||
# LoRA kernel metadata
|
||||
lora_kernel_meta: LoRAKernelMeta
|
||||
# Metadata tensors used in testing correctness
|
||||
seq_lens: torch.Tensor
|
||||
seq_start_loc: torch.Tensor
|
||||
prompt_lora_mapping: torch.Tensor
|
||||
token_lora_mapping: torch.Tensor
|
||||
# v1 kernel metadata
|
||||
v1_kernel_meta: Optional[V1KernelMeta] = None
|
||||
|
||||
def io_types(self) -> str:
|
||||
return (f"{dtype_to_str(self.input.dtype)}x"
|
||||
@ -444,35 +348,29 @@ class BenchmarkTensors:
|
||||
assert ctx.num_active_loras <= ctx.num_loras
|
||||
total_tokens = ctx.batch_size * ctx.seq_length
|
||||
|
||||
# Make metadata tensors involved in correctness testing.
|
||||
# Prepare seq lens tensor
|
||||
seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
|
||||
(ctx.batch_size, ))
|
||||
# Prepare seq_start_loc tensor
|
||||
seq_start_loc_tensor = torch.cumsum(torch.tensor(
|
||||
[0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
|
||||
dim=0)
|
||||
assert total_tokens == seq_len_tensor.sum()
|
||||
# Prepare prompt lora indices tensor
|
||||
prompt_lora_indices_tensor = make_prompt_lora_mapping(
|
||||
ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
|
||||
# Prepare token lora indices tensor
|
||||
|
||||
# Make LoRAKernelMeta
|
||||
token_lora_indices_tensor = make_token_lora_mapping(
|
||||
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
|
||||
seq_len_tensor, "cpu")
|
||||
|
||||
v1_kernel_meta = None
|
||||
if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
|
||||
v1_kernel_meta = V1KernelMeta.make(
|
||||
lora_kernel_meta = LoRAKernelMeta.make(
|
||||
max_loras=ctx.num_loras,
|
||||
max_num_tokens=token_lora_indices_tensor.size(0),
|
||||
device="cpu")
|
||||
v1_kernel_meta.prepare_tensors(
|
||||
lora_kernel_meta.prepare_tensors(
|
||||
token_lora_mapping=token_lora_indices_tensor)
|
||||
|
||||
return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
|
||||
seq_len_tensor, seq_start_loc_tensor,
|
||||
prompt_lora_indices_tensor,
|
||||
token_lora_indices_tensor, v1_kernel_meta)
|
||||
lora_kernel_meta, seq_len_tensor,
|
||||
prompt_lora_indices_tensor)
|
||||
|
||||
def sanity_check(self) -> None:
|
||||
"""
|
||||
@ -482,9 +380,9 @@ class BenchmarkTensors:
|
||||
# check metadata tensors
|
||||
assert torch.sum(self.seq_lens) == num_tokens
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
assert self.seq_start_loc.shape[0] == num_seqs
|
||||
#assert self.seq_start_loc.shape[0] == num_seqs
|
||||
assert self.prompt_lora_mapping.shape[0] == num_seqs
|
||||
assert self.token_lora_mapping.shape[0] == num_tokens
|
||||
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
|
||||
|
||||
def to_device(self, device: str):
|
||||
"""
|
||||
@ -499,220 +397,27 @@ class BenchmarkTensors:
|
||||
self.input = to_device(self.input)
|
||||
self.output = to_device(self.output)
|
||||
self.seq_lens = to_device(self.seq_lens)
|
||||
self.seq_start_loc = to_device(self.seq_start_loc)
|
||||
self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
|
||||
self.token_lora_mapping = to_device(self.token_lora_mapping)
|
||||
for i in range(len(self.lora_weights_lst)):
|
||||
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
|
||||
|
||||
# v1 meta
|
||||
if self.v1_kernel_meta:
|
||||
for field_name in V1KernelMeta.__dataclass_fields__:
|
||||
field = getattr(self.v1_kernel_meta, field_name)
|
||||
# LoRA meta
|
||||
for field_name in LoRAKernelMeta.__dataclass_fields__:
|
||||
field = getattr(self.lora_kernel_meta, field_name)
|
||||
assert isinstance(field, torch.Tensor)
|
||||
setattr(self.v1_kernel_meta, field_name, to_device(field))
|
||||
setattr(self.lora_kernel_meta, field_name, to_device(field))
|
||||
|
||||
def metadata(self) -> tuple[int, int, int]:
|
||||
"""
|
||||
Return num_seqs, num_tokens and max_seq_len
|
||||
"""
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
num_tokens = self.token_lora_mapping.shape[0]
|
||||
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
|
||||
max_seq_len = torch.max(self.seq_lens).item()
|
||||
num_slices = len(self.lora_weights_lst)
|
||||
return num_seqs, num_tokens, max_seq_len, num_slices
|
||||
|
||||
def convert_to_sgmv_benchmark_tensors(self):
|
||||
"""
|
||||
For sgmv punica kernels, when consecutive sequences have the
|
||||
same LoRA ID, we just merge them together.
|
||||
This happens in punica.py::compute_metadata
|
||||
"""
|
||||
|
||||
# Collapse seq_lens and seq_start_loc
|
||||
_, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
|
||||
return_counts=True)
|
||||
cum_result = torch.cumsum(seq_lens, dim=0)
|
||||
seq_start_loc = torch.zeros_like(seq_lens)
|
||||
seq_start_loc[1:].copy_(cum_result[:-1])
|
||||
|
||||
# Collapse prompt mapping
|
||||
prompt_lora_mapping = torch.unique_consecutive(
|
||||
self.prompt_lora_mapping)
|
||||
|
||||
assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
|
||||
f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"
|
||||
|
||||
self.prompt_lora_mapping = prompt_lora_mapping.to(
|
||||
dtype=self.prompt_lora_mapping.dtype)
|
||||
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
|
||||
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
|
||||
|
||||
def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_tokens, hidden_size]
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
hidden_size = i_shape[1]
|
||||
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == hidden_size
|
||||
lora_rank = lw_shape[1]
|
||||
# Expected output shape [num_slices, num_tokens, lora_rank]
|
||||
assert len(o_shape) == 3
|
||||
assert o_shape == (num_slices, num_tokens, lora_rank)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_a_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'b_seq_start_loc': self.seq_start_loc,
|
||||
'seq_len_tensor': self.seq_lens,
|
||||
'lora_indices_tensor': self.prompt_lora_mapping,
|
||||
'batches': num_seqs,
|
||||
'max_seq_length': max_seq_len,
|
||||
'token_nums': num_tokens,
|
||||
'scaling': 1.0,
|
||||
}
|
||||
|
||||
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape : [num_slices, num_tokens, lora_rank]
|
||||
assert len(i_shape) == 3
|
||||
assert i_shape[0] == num_slices
|
||||
assert i_shape[1] == num_tokens
|
||||
lora_rank = i_shape[2]
|
||||
# Expected lora weight shape : [num_lora, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == lora_rank
|
||||
hidden_size = lw_shape[1]
|
||||
# Expected output shape : [num_tokens, hidden_size * num_slices]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, hidden_size * num_slices)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_b_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'b_seq_start_loc': self.seq_start_loc,
|
||||
'seq_len_tensor': self.seq_lens,
|
||||
'lora_indices_tensor': self.prompt_lora_mapping,
|
||||
'batches': num_seqs,
|
||||
'max_seq_length': max_seq_len,
|
||||
'token_nums': num_tokens,
|
||||
'offset_start': 0,
|
||||
'add_inputs': add_inputs,
|
||||
}
|
||||
|
||||
def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
assert len(self.lora_weights_lst) == 1
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, _ = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_tokens, hidden_size]
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
hidden_size = i_shape[1]
|
||||
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == hidden_size
|
||||
lora_rank = lw_shape[1]
|
||||
# Expected output shape [num_tokens, lora_rank]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, lora_rank)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_a_weights': self.lora_weights_lst[0],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'scaling': 1.0
|
||||
}
|
||||
|
||||
def as_bgmv_expand_kwargs(self, add_inputs: bool):
|
||||
assert len(self.lora_weights_lst) == 1
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, _ = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_tokens, lora_rank]
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
lora_rank = i_shape[1]
|
||||
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == lora_rank
|
||||
hidden_size = lw_shape[1]
|
||||
# Expected output shape [num_tokens, hidden_size]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, hidden_size)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_b_weights': self.lora_weights_lst[0],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'add_inputs': add_inputs
|
||||
}
|
||||
|
||||
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_slices, num_tokens, lora_rank]
|
||||
assert len(i_shape) == 3
|
||||
assert i_shape[0] == num_slices
|
||||
assert i_shape[1] == num_tokens
|
||||
lora_rank = i_shape[2]
|
||||
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == lora_rank
|
||||
hidden_size = lw_shape[1]
|
||||
# Expected output shape [num_tokens, hidden_size * num_slices]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, hidden_size * num_slices)
|
||||
|
||||
self.to_device(self.input.device)
|
||||
|
||||
kwargs_list = []
|
||||
for i in range(num_slices):
|
||||
kwargs_list.append({
|
||||
'inputs': self.input[i],
|
||||
'lora_b_weights': self.lora_weights_lst[i],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'slice_offset': i * hidden_size,
|
||||
'slice_size': hidden_size,
|
||||
'add_inputs': add_inputs,
|
||||
})
|
||||
return {'kwargs_list': kwargs_list}
|
||||
|
||||
def as_v1_shrink_kwargs(self) -> dict[str, Any]:
|
||||
assert self.v1_kernel_meta is not None
|
||||
def as_lora_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
@ -737,17 +442,16 @@ class BenchmarkTensors:
|
||||
'inputs': self.input,
|
||||
'lora_a_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
|
||||
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
|
||||
'token_indices_sorted_by_lora_ids':
|
||||
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.v1_kernel_meta.active_lora_ids,
|
||||
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.lora_kernel_meta.active_lora_ids,
|
||||
'scaling': 1.0,
|
||||
}
|
||||
|
||||
def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
assert self.v1_kernel_meta is not None
|
||||
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
@ -773,12 +477,12 @@ class BenchmarkTensors:
|
||||
'inputs': self.input,
|
||||
'lora_b_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
|
||||
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
|
||||
'token_indices_sorted_by_lora_ids':
|
||||
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.v1_kernel_meta.active_lora_ids,
|
||||
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.lora_kernel_meta.active_lora_ids,
|
||||
'offset_start': 0,
|
||||
'add_inputs': add_inputs,
|
||||
}
|
||||
@ -791,20 +495,10 @@ class BenchmarkTensors:
|
||||
else:
|
||||
assert add_inputs is not None
|
||||
|
||||
if op_type == OpType.SGMV_SHRINK:
|
||||
return self.as_sgmv_shrink_kwargs()
|
||||
if op_type == OpType.SGMV_EXPAND:
|
||||
return self.as_sgmv_expand_kwargs(add_inputs)
|
||||
if op_type == OpType.BGMV_SHRINK:
|
||||
return self.as_bgmv_shrink_kwargs()
|
||||
if op_type == OpType.BGMV_EXPAND:
|
||||
return self.as_bgmv_expand_kwargs(add_inputs)
|
||||
if op_type == OpType.BGMV_EXPAND_SLICE:
|
||||
return self.as_bgmv_expand_slice_kwargs(add_inputs)
|
||||
if op_type == OpType.V1_SHRINK:
|
||||
return self.as_v1_shrink_kwargs()
|
||||
if op_type == OpType.V1_EXPAND:
|
||||
return self.as_v1_expand_kwargs(add_inputs)
|
||||
if op_type == OpType.LORA_SHRINK:
|
||||
return self.as_lora_shrink_kwargs()
|
||||
if op_type == OpType.LORA_EXPAND:
|
||||
return self.as_lora_expand_kwargs(add_inputs)
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
def test_correctness(self, op_type: OpType,
|
||||
@ -993,10 +687,6 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
|
||||
for bench_ctx in bench_ctxs:
|
||||
for seq_len in args.seq_lengths:
|
||||
bench_ops: list[OpType] = args.op_types
|
||||
if seq_len > 1:
|
||||
# bench only prefill ops
|
||||
bench_ops = [op for op in args.op_types if op.is_prefill_op()]
|
||||
|
||||
seq_len_timers = []
|
||||
for bench_op in bench_ops:
|
||||
for num_slices in bench_op.num_slices():
|
||||
@ -1206,13 +896,13 @@ Benchmark LoRA kernels:
|
||||
{use_cuda_graph_recommendation()}
|
||||
|
||||
list_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
|
||||
model_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
|
||||
range_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
|
||||
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
|
@ -477,6 +477,11 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `Zamba2ForCausalLM`
|
||||
* Zamba2
|
||||
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
|
||||
*
|
||||
*
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
|
@ -93,7 +93,6 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
|
@ -682,7 +682,6 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
|
@ -342,7 +342,6 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
)
|
||||
|
||||
placeholders = "".join(f"<|image_{i}|>"
|
||||
|
@ -4,18 +4,13 @@ from threading import Lock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.lora.ops.triton_ops # noqa: F401
|
||||
import vllm.lora.ops.triton_ops.v1 # noqa: F401
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
import vllm.lora.ops.torch_ops as torch_ops
|
||||
import vllm.lora.ops.triton_ops as triton_ops
|
||||
from vllm.lora.ops.triton_ops import LoRAKernelMeta
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (PunicaTensors, assert_close, generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
generate_data_for_nslices)
|
||||
from .utils import PunicaTensors, assert_close, generate_data_for_nslices
|
||||
|
||||
|
||||
# Utility shrink and expand operations used as reference implementations.
|
||||
@ -26,10 +21,10 @@ def sgmv_shrink_for_nslices(
|
||||
prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
|
||||
num_tokens: int, scaling: float):
|
||||
"""
|
||||
Wrapper around sgmv_shrink that handles any nslices.
|
||||
Wrapper around torch_ops.sgmv_shrink that handles any nslices.
|
||||
"""
|
||||
for index in range(nslices):
|
||||
sgmv_shrink(
|
||||
torch_ops.sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights_lst[index],
|
||||
out_tensor[index],
|
||||
@ -53,11 +48,11 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
|
||||
max_seq_length: int, num_tokens: int,
|
||||
add_inputs: bool) -> None:
|
||||
"""
|
||||
Wrapper around sgmv_expand that handles any nslices.
|
||||
Wrapper around torch_ops.sgmv_expand that handles any nslices.
|
||||
"""
|
||||
if nslices == 1:
|
||||
# Verify the torch's sgmv_expand op
|
||||
sgmv_expand(
|
||||
torch_ops.sgmv_expand(
|
||||
inputs_tensor[0],
|
||||
lora_weights_lst[0],
|
||||
out_tensor,
|
||||
@ -73,7 +68,7 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
sgmv_expand_slice(
|
||||
torch_ops.sgmv_expand_slice(
|
||||
inputs_tensor[index],
|
||||
lora_weights,
|
||||
out_tensor,
|
||||
@ -93,12 +88,13 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
|
||||
_dict_lock = Lock()
|
||||
|
||||
|
||||
def check_shrink_kernels(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||
device: str, seq_length: int, scaling: float):
|
||||
def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int,
|
||||
dtype: torch.dtype, device: str, seq_length: int,
|
||||
scaling: float):
|
||||
"""
|
||||
Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernel against a
|
||||
reference implementation.
|
||||
Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
|
||||
kernels.
|
||||
"""
|
||||
data: PunicaTensors = generate_data_for_nslices(
|
||||
batches,
|
||||
@ -118,35 +114,24 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int,
|
||||
data.prompt_lora_mapping, batches, max_seq_length,
|
||||
token_nums)
|
||||
|
||||
# Setup metadata information for the V1 kernel.
|
||||
v1_meta = V1KernelMeta.make(max_loras=num_loras,
|
||||
# Setup metadata information for the LoRA kernel.
|
||||
lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
|
||||
max_num_tokens=token_nums,
|
||||
device='cuda')
|
||||
v1_meta.prepare_tensors(data.token_lora_mapping)
|
||||
lora_meta.prepare_tensors(data.token_lora_mapping)
|
||||
|
||||
ref_out_tensor = data.ref_out_tensor
|
||||
sgmv_out_tensor = data.our_out_tensor
|
||||
v1_out_tensor = data.our_out_tensor.clone()
|
||||
out_tensor = data.our_out_tensor.clone()
|
||||
|
||||
# Preventing cache error pointer.
|
||||
with _dict_lock:
|
||||
# SGMV shrink kernel
|
||||
# lora_shrink kernel
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
torch.ops.vllm.sgmv_shrink(
|
||||
triton_ops.lora_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
sgmv_out_tensor,
|
||||
*sgmv_meta_args,
|
||||
scaling,
|
||||
)
|
||||
|
||||
# V1 shrink kernel
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
torch.ops.vllm.v1_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
v1_out_tensor,
|
||||
*v1_meta.meta_args(token_nums=token_nums),
|
||||
out_tensor,
|
||||
*lora_meta.meta_args(token_nums=token_nums),
|
||||
scaling,
|
||||
)
|
||||
|
||||
@ -160,16 +145,16 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int,
|
||||
scaling,
|
||||
)
|
||||
|
||||
assert_close(sgmv_out_tensor, ref_out_tensor)
|
||||
assert_close(v1_out_tensor, ref_out_tensor)
|
||||
assert_close(out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
def check_expand_kernels(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||
device: str, seq_length: int, add_inputs: bool):
|
||||
def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int,
|
||||
dtype: torch.dtype, device: str, seq_length: int,
|
||||
add_inputs: bool):
|
||||
"""
|
||||
Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a
|
||||
reference implementation.
|
||||
Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
|
||||
kernels.
|
||||
"""
|
||||
data: PunicaTensors = generate_data_for_nslices(
|
||||
batches,
|
||||
@ -190,35 +175,23 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int,
|
||||
data.prompt_lora_mapping, batches, max_seq_length,
|
||||
token_nums)
|
||||
|
||||
# Setup metadata information for the V1 kernel.
|
||||
v1_meta = V1KernelMeta.make(max_loras=num_loras,
|
||||
# Setup metadata information for the LoRA kernel.
|
||||
lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
|
||||
max_num_tokens=token_nums,
|
||||
device='cuda')
|
||||
v1_meta.prepare_tensors(data.token_lora_mapping)
|
||||
lora_meta.prepare_tensors(data.token_lora_mapping)
|
||||
|
||||
# Setup output tensors
|
||||
ref_out_tensor = data.ref_out_tensor
|
||||
sgmv_out_tensor = data.our_out_tensor
|
||||
v1_out_tensor = data.our_out_tensor.clone()
|
||||
out_tensor = data.our_out_tensor.clone()
|
||||
|
||||
with _dict_lock:
|
||||
# SGMV expand kernel
|
||||
# lora_expand kernel
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
torch.ops.vllm.sgmv_expand(
|
||||
data.inputs_tensor,
|
||||
triton_ops.lora_expand(data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
sgmv_out_tensor,
|
||||
*sgmv_meta_args,
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
# V1 expand kernel
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
torch.ops.vllm.v1_expand(data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
v1_out_tensor,
|
||||
*v1_meta.meta_args(token_nums=token_nums),
|
||||
out_tensor,
|
||||
*lora_meta.meta_args(token_nums=token_nums),
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs)
|
||||
|
||||
@ -231,124 +204,7 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int,
|
||||
*sgmv_meta_args,
|
||||
add_inputs=add_inputs)
|
||||
|
||||
assert_close(sgmv_out_tensor, ref_out_tensor)
|
||||
assert_close(v1_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, dtype: torch.dtype, device: str,
|
||||
scaling: float):
|
||||
"""
|
||||
Compare vllm.bgmv_shrink against a reference implementation.
|
||||
"""
|
||||
seq_length = 1
|
||||
data: PunicaTensors = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
"shrink",
|
||||
device,
|
||||
)
|
||||
|
||||
torch.ops.vllm.bgmv_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.our_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
scaling,
|
||||
)
|
||||
|
||||
bgmv_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.ref_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
scaling,
|
||||
)
|
||||
|
||||
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
|
||||
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, dtype: torch.dtype, device: str,
|
||||
add_inputs: bool):
|
||||
"""
|
||||
Compare vllm.bgmv_expand against a reference implementation.
|
||||
"""
|
||||
seq_length = 1
|
||||
data: PunicaTensors = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
"expand",
|
||||
device,
|
||||
)
|
||||
|
||||
torch.ops.vllm.bgmv_expand(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.our_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
bgmv_expand(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.ref_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||
device: str, add_inputs: bool):
|
||||
"""
|
||||
Compare vllm.bgmv_expand_slice against a reference implementation.
|
||||
"""
|
||||
seq_length = 1
|
||||
data: PunicaTensors = generate_data_for_expand_nslices(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
nslices,
|
||||
device,
|
||||
)
|
||||
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
torch.ops.vllm.bgmv_expand_slice(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights[index],
|
||||
data.our_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
bgmv_expand_slice(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights[index],
|
||||
data.ref_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
slice_offset += hidden_size
|
||||
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
assert_close(out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
# Tests
|
||||
@ -490,13 +346,13 @@ def test_kernels(
|
||||
op_type: str,
|
||||
):
|
||||
"""
|
||||
Tests SGMV and V1 kernels.
|
||||
Tests LoRA kernels.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
if op_type == "shrink":
|
||||
check_shrink_kernels(batches=batches,
|
||||
check_lora_shrink_kernel(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
@ -506,7 +362,7 @@ def test_kernels(
|
||||
seq_length=128,
|
||||
scaling=0.5)
|
||||
else:
|
||||
check_expand_kernels(batches=batches,
|
||||
check_lora_expand_kernel(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
@ -538,13 +394,13 @@ def test_kernels_hidden_size(
|
||||
op_type: str,
|
||||
):
|
||||
"""
|
||||
Tests SGMV and V1 kernels.
|
||||
Tests SGMV and LoRA kernels.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
if op_type == "shrink":
|
||||
check_shrink_kernels(batches=batches,
|
||||
check_lora_shrink_kernel(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
@ -554,7 +410,7 @@ def test_kernels_hidden_size(
|
||||
seq_length=128,
|
||||
scaling=0.5)
|
||||
else:
|
||||
check_expand_kernels(batches=batches,
|
||||
check_lora_expand_kernel(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
@ -563,134 +419,3 @@ def test_kernels_hidden_size(
|
||||
device=device,
|
||||
seq_length=128,
|
||||
add_inputs=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params['batches'])
|
||||
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
|
||||
@pytest.mark.parametrize("rank", test_params['max_ranks'])
|
||||
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
def test_punica_bgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
seed: int,
|
||||
op_type: str,
|
||||
):
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
if op_type == "shrink":
|
||||
check_bgmv_shrink(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
scaling=0.5)
|
||||
else:
|
||||
check_bgmv_expand(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
add_inputs=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", hs_test_params['batches'])
|
||||
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
|
||||
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
|
||||
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
def test_punica_bgmv_hidden_size(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
seed: int,
|
||||
op_type: str,
|
||||
):
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
if op_type == "shrink":
|
||||
check_bgmv_shrink(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
scaling=0.5)
|
||||
else:
|
||||
check_bgmv_expand(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
add_inputs=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params['batches'])
|
||||
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
|
||||
@pytest.mark.parametrize("rank", test_params['max_ranks'])
|
||||
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int,
|
||||
dtype: torch.dtype, device: str,
|
||||
seed: int):
|
||||
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
check_bgmv_expand_slice(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
add_inputs=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", hs_test_params['batches'])
|
||||
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
|
||||
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
|
||||
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int,
|
||||
rank: int, hidden_size: int,
|
||||
nslices: int,
|
||||
dtype: torch.dtype,
|
||||
device: str, seed: int):
|
||||
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
check_bgmv_expand_slice(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
add_inputs=True)
|
||||
|
@ -9,7 +9,7 @@ from vllm.sampling_params import SamplingParams
|
||||
from ...utils import check_outputs_equal
|
||||
|
||||
# This test is for the hybrid models
|
||||
MODELS = ["ai21labs/Jamba-tiny-dev"]
|
||||
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
|
||||
# Bamba at Fp32 is too big for the CI (L4 GPU).
|
||||
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
|
||||
|
||||
@ -27,17 +27,19 @@ def test_models(
|
||||
) -> None:
|
||||
|
||||
# numeric error produces different generation
|
||||
if 'Bamba' in model:
|
||||
if "Bamba" in model:
|
||||
example_prompts.pop(3)
|
||||
|
||||
with hf_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
model_kwargs={
|
||||
"use_mamba_kernels":
|
||||
False, # mamba kernels are not installed so HF
|
||||
model_kwargs = {
|
||||
"use_mamba_kernels": False, # mamba kernels are not installed so HF
|
||||
# don't use them
|
||||
}) as hf_model:
|
||||
}
|
||||
if "Zamba2" in model:
|
||||
# Zamba2 HF implementation automatically checks if mamba kernels are
|
||||
# installed
|
||||
model_kwargs = {}
|
||||
|
||||
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
|
||||
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
|
||||
model: str, dtype: str,
|
||||
max_tokens: int) -> None:
|
||||
# numeric error during prefill chucking produces different generation
|
||||
# numeric error during prefill chunking produces different generation
|
||||
# compared to w/o prefill chunking for those examples, removed them for now
|
||||
if 'Jamba' in model:
|
||||
if "Jamba" in model:
|
||||
example_prompts.pop(7)
|
||||
example_prompts.pop(2)
|
||||
example_prompts.pop(1)
|
||||
elif 'Bamba' in model:
|
||||
elif "Bamba" in model:
|
||||
example_prompts.pop(6)
|
||||
example_prompts.pop(3)
|
||||
example_prompts.pop(2)
|
||||
dtype = "half" # use a different dtype for Bamba
|
||||
elif "Zamba2" in model:
|
||||
example_prompts.pop(7)
|
||||
dtype = "half"
|
||||
|
||||
with hf_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
model_kwargs={
|
||||
"use_mamba_kernels":
|
||||
False, # mamba kernels are not installed so HF
|
||||
model_kwargs = {
|
||||
"use_mamba_kernels": False, # mamba kernels are not installed so HF
|
||||
# don't use them
|
||||
}) as hf_model:
|
||||
}
|
||||
if "Zamba2" in model:
|
||||
# Zamba2 HF implementation automatically checks if mamba kernels are
|
||||
# installed
|
||||
model_kwargs = {}
|
||||
|
||||
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
|
||||
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model,
|
||||
|
@ -100,7 +100,6 @@ def run_test(
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI
|
||||
enforce_eager=True,
|
||||
) as vllm_model:
|
||||
|
@ -195,6 +195,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
|
||||
is_available_online=False,
|
||||
trust_remote_code=True),
|
||||
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
|
||||
min_transformers_version="4.49"),
|
||||
# [Encoder-decoder]
|
||||
"BartModel": _HfExamplesInfo("facebook/bart-base"),
|
||||
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
|
||||
|
@ -821,6 +821,11 @@ class ModelConfig:
|
||||
if qk_rope_head_dim and qk_nope_head_dim:
|
||||
return qk_rope_head_dim + qk_nope_head_dim
|
||||
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
return self.hf_text_config.attention_head_dim
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
@ -904,7 +909,9 @@ class ModelConfig:
|
||||
else:
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_hidden_layers", 0)
|
||||
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
|
||||
# the layout order is: DP x PP x TP
|
||||
pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
|
||||
) % parallel_config.pipeline_parallel_size
|
||||
pp_size = parallel_config.pipeline_parallel_size
|
||||
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
|
||||
return start, end
|
||||
@ -942,6 +949,15 @@ class ModelConfig:
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
|
||||
@ -2308,7 +2324,7 @@ class LoRAConfig:
|
||||
# Setting the maximum rank to 512 should be able to satisfy the vast
|
||||
# majority of applications.
|
||||
possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
|
||||
possible_lora_extra_vocab_size = (0, 256, 512)
|
||||
possible_lora_extra_vocab_size = (256, 512)
|
||||
if self.max_lora_rank not in possible_max_ranks:
|
||||
raise ValueError(
|
||||
f"max_lora_rank ({self.max_lora_rank}) must be one of "
|
||||
|
@ -897,9 +897,22 @@ def initialize_model_parallel(
|
||||
get_world_group().device_group)
|
||||
|
||||
data_parallel_size = 1
|
||||
has_external_dp = False
|
||||
from vllm.config import get_current_vllm_config
|
||||
config = get_current_vllm_config()
|
||||
if config is not None:
|
||||
if config.parallel_config.world_size != world_size:
|
||||
# detect external data parallelism.
|
||||
# dp in vllm means all dp instances need to run together.
|
||||
# if the world size does not match, it means this dp is external,
|
||||
# and the dp instances can run independently, e.g. in rlhf workflow
|
||||
# from https://github.com/volcengine/verl .
|
||||
# in that case, we treat the rest dimensions as if they are
|
||||
# data parallel, and create a dummy dp group that is not used.
|
||||
data_parallel_size = world_size // (pipeline_model_parallel_size *
|
||||
tensor_model_parallel_size)
|
||||
has_external_dp = True
|
||||
else:
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
|
||||
# the layout order is: DP x PP x TP
|
||||
@ -940,6 +953,12 @@ def initialize_model_parallel(
|
||||
2).reshape(-1,
|
||||
data_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if has_external_dp:
|
||||
# create a dummy dp group that is not used actually,
|
||||
# since this dp is external.
|
||||
# a dummy dp group means every rank is a group itself.
|
||||
# this way, no communication is needed, no memory is wasted.
|
||||
group_ranks = [[x] for x in range(world_size)]
|
||||
_DP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
|
@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
|
||||
Tuple, Type, Union, cast, get_args)
|
||||
@ -1576,6 +1577,11 @@ class EngineArgs:
|
||||
#############################################################
|
||||
# Experimental Features - allow users to opt in.
|
||||
|
||||
# Signal Handlers requires running in main thread.
|
||||
if (threading.current_thread() != threading.main_thread()
|
||||
and _warn_or_fallback("Engine in background thread")):
|
||||
return False
|
||||
|
||||
# LoRA is supported on V1, but off by default for now.
|
||||
if self.enable_lora and _warn_or_fallback("LORA"):
|
||||
return False
|
||||
|
@ -29,6 +29,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.worker.model_runner_base import InputProcessingError
|
||||
|
||||
@ -428,6 +430,9 @@ def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
|
||||
ipc_path: str, disable_log_stats: bool,
|
||||
disable_log_requests: bool, engine_alive):
|
||||
try:
|
||||
# Ensure we can serialize transformer config before spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
engine = MQLLMEngine.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
|
@ -82,6 +82,8 @@ from vllm.entrypoints.openai.serving_transcription import (
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address, set_ulimit)
|
||||
@ -221,6 +223,9 @@ async def build_async_engine_client_from_engine_args(
|
||||
# so we need to spawn a new process
|
||||
context = multiprocessing.get_context("spawn")
|
||||
|
||||
# Ensure we can serialize transformer config before spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
# The Process can raise an exception during startup, which may
|
||||
# not actually result in an exitcode being reported. As a result
|
||||
# we use a shared variable to communicate the information.
|
||||
|
@ -1,15 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401
|
||||
from vllm.lora.ops.triton_ops.lora_expand import lora_expand
|
||||
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
|
||||
from vllm.lora.ops.triton_ops.lora_shrink import lora_shrink
|
||||
|
||||
__all__ = [
|
||||
"bgmv_expand",
|
||||
"bgmv_expand_slice",
|
||||
"bgmv_shrink",
|
||||
"sgmv_expand",
|
||||
"sgmv_shrink",
|
||||
"lora_expand",
|
||||
"lora_shrink",
|
||||
"LoRAKernelMeta",
|
||||
]
|
||||
|
@ -1,188 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bgmv_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride,
|
||||
l0_stride,
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SPLIT_N: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
|
||||
performance
|
||||
"""
|
||||
pid_sn = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
offset_n = tl.arange(0, BLOCK_N)
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
|
||||
offset_k * xk_stride, ) # [BLOCK_K]
|
||||
else:
|
||||
tiled_a = tl.load(
|
||||
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
|
||||
mask=offset_k < K,
|
||||
other=0,
|
||||
) # [BLOCK_K]
|
||||
# N must be divisible by SPLIT_N
|
||||
split_n_length = tl.cdiv(N, SPLIT_N)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
# sliding to next row-block
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
pid_sn * split_n_length * lora_k_stride)
|
||||
c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
|
||||
for n in range(0, split_n_length, BLOCK_N):
|
||||
current_n = n + offset_n
|
||||
current_n_c = tl.max_contiguous(current_n, BLOCK_N)
|
||||
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
|
||||
< K)
|
||||
c_mask = current_n < split_n_length
|
||||
tiled_b = tl.load(
|
||||
b_ptr + current_n_c[:, None] * lora_k_stride +
|
||||
offset_k[None, :] * lora_n_stride,
|
||||
mask=b_ptr_mask,
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride,
|
||||
mask=c_mask,
|
||||
other=0.0)
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||
else:
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||
|
||||
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): lora'a weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch, An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
BLOCK_K = triton.next_power_of_2(K)
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
batches = lora_indices_tensor.size(0)
|
||||
config = get_lora_op_configs("expand", batches, N)
|
||||
grid = lambda META: (
|
||||
META["SPLIT_N"],
|
||||
batches,
|
||||
)
|
||||
_bgmv_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_K=BLOCK_K,
|
||||
EVEN_K=EVEN_K,
|
||||
ADD_INPUTS=ADD_INPUTS,
|
||||
CAST_TYPE=CAST_TYPE,
|
||||
**config,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def bgmv_expand_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="bgmv_expand",
|
||||
op_func=_bgmv_expand,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=bgmv_expand_fake,
|
||||
)
|
||||
bgmv_expand = torch.ops.vllm.bgmv_expand
|
||||
|
||||
except AttributeError:
|
||||
bgmv_expand = _bgmv_expand
|
@ -1,207 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bgmv_expand_slice_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride,
|
||||
l0_stride,
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
slice_offset,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SPLIT_N: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
|
||||
performance
|
||||
"""
|
||||
pid_sn = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
offset_n = tl.arange(0, BLOCK_N)
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
|
||||
offset_k * xk_stride, ) # [BLOCK_K]
|
||||
else:
|
||||
tiled_a = tl.load(
|
||||
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
|
||||
mask=offset_k < K,
|
||||
other=0,
|
||||
) # [BLOCK_K]
|
||||
# N must be divisible by SPLIT_N
|
||||
split_n_length = tl.cdiv(N, SPLIT_N)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
# sliding to next row-block
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
pid_sn * split_n_length * lora_k_stride)
|
||||
c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
|
||||
slice_offset * cn_stride)
|
||||
|
||||
for n in range(0, split_n_length, BLOCK_N):
|
||||
current_n = n + offset_n
|
||||
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
|
||||
< K)
|
||||
c_mask = current_n < split_n_length
|
||||
tiled_b = tl.load(
|
||||
b_ptr + current_n[:, None] * lora_k_stride +
|
||||
offset_k[None, :] * lora_n_stride,
|
||||
mask=b_ptr_mask,
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
|
||||
if ADD_INPUTS:
|
||||
# explicitly pass in other=None to tell triton that masked values
|
||||
# can be uninitialized. This is OK because the later tl.store
|
||||
# operation uses the same mask, eliminating the risk of garbage
|
||||
# values propagating
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride,
|
||||
mask=c_mask,
|
||||
other=None)
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||
else:
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||
|
||||
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): lora'b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch, An index of -1 means no lora should be
|
||||
applied.
|
||||
slice_offset (int): output_tensor's offset
|
||||
slice_size (int): current output_tensor's size
|
||||
batches (int): batch size
|
||||
add_inputs (bool, optional): Defaults to False.
|
||||
"""
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
|
||||
assert slice_size == lora_b_weights.size(-2)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
BLOCK_K = triton.next_power_of_2(K)
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
|
||||
batches = lora_indices_tensor.size(0)
|
||||
|
||||
config = get_lora_op_configs("expand", batches, N)
|
||||
|
||||
grid = lambda META: (
|
||||
META["SPLIT_N"],
|
||||
batches,
|
||||
)
|
||||
_bgmv_expand_slice_kernel[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
slice_offset,
|
||||
BLOCK_K=BLOCK_K,
|
||||
EVEN_K=EVEN_K,
|
||||
ADD_INPUTS=ADD_INPUTS,
|
||||
CAST_TYPE=CAST_TYPE,
|
||||
**config,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def bgmv_expand_slice_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="bgmv_expand_slice",
|
||||
op_func=_bgmv_expand_slice,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=bgmv_expand_slice_fake,
|
||||
)
|
||||
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
|
||||
|
||||
except AttributeError:
|
||||
bgmv_expand_slice = _bgmv_expand_slice
|
@ -1,168 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bgmv_shrink_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
lora_indices,
|
||||
scaling,
|
||||
xm_stride,
|
||||
xk_stride,
|
||||
l0_stride,
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's
|
||||
performance
|
||||
"""
|
||||
pid_sk = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
|
||||
offset_n = tl.arange(0, BLOCK_N)
|
||||
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
|
||||
a_ptr = input_ptr + cur_batch * xm_stride
|
||||
b_ptr = lora_ptr + l0_stride * lora_index
|
||||
accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_K * SPLIT_K):
|
||||
current_k = k + offset_k
|
||||
current_k_c = tl.max_contiguous(current_k, BLOCK_K)
|
||||
tiled_a = tl.load(
|
||||
a_ptr + current_k_c,
|
||||
mask=current_k < K,
|
||||
other=0.0,
|
||||
) # [BLOCK_K]
|
||||
b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
|
||||
|
||||
tiled_b = tl.load(
|
||||
b_ptr + offset_n[:, None] * lora_k_stride +
|
||||
current_k[None, :] * lora_n_stride,
|
||||
mask=b_ptr_mask,
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
|
||||
accumulator += tl.sum(tiled_a * tiled_b, 1)
|
||||
accumulator *= scaling
|
||||
offset_cn = tl.arange(0, BLOCK_N)
|
||||
c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
|
||||
c_mask = offset_cn < N
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_a_weights (torch.Tensor): lora'a weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights.dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
assert lora_a_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_a_weights.size(-1)
|
||||
assert inputs.is_contiguous()
|
||||
|
||||
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
|
||||
assert lora_a_weights.size(1) == 1
|
||||
lora_a_weights = lora_a_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
|
||||
assert lora_a_weights.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
# TODO tuning this config
|
||||
batches = lora_indices_tensor.size(0)
|
||||
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
|
||||
BLOCK_N = triton.next_power_of_2(N)
|
||||
# First try to load optimal config from the file
|
||||
config = get_lora_op_configs("bgmv_shrink", batches, K)
|
||||
|
||||
grid = lambda META: (
|
||||
META["SPLIT_K"],
|
||||
batches,
|
||||
)
|
||||
_bgmv_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
lora_indices_tensor,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_a_weights.stride(0),
|
||||
lora_a_weights.stride(1),
|
||||
lora_a_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_N=BLOCK_N,
|
||||
**config,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def bgmv_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="bgmv_shrink",
|
||||
op_func=_bgmv_shrink,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=bgmv_shrink_fake,
|
||||
)
|
||||
bgmv_shrink = torch.ops.vllm.bgmv_shrink
|
||||
|
||||
except AttributeError:
|
||||
bgmv_shrink = _bgmv_shrink
|
@ -18,7 +18,7 @@ from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _v1_expand_kernel(
|
||||
def _lora_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
@ -125,7 +125,7 @@ def _v1_expand_kernel(
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _v1_expand(
|
||||
def _lora_expand(
|
||||
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
|
||||
lora_b_weights: List[
|
||||
torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
|
||||
@ -216,7 +216,7 @@ def _v1_expand(
|
||||
MAX_LORAS,
|
||||
)
|
||||
|
||||
_v1_expand_kernel[grid](
|
||||
_lora_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
@ -254,7 +254,7 @@ def _v1_expand(
|
||||
return
|
||||
|
||||
|
||||
def _v1_expand_fake(
|
||||
def _lora_expand_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
@ -271,12 +271,12 @@ def _v1_expand_fake(
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="v1_expand",
|
||||
op_func=_v1_expand,
|
||||
op_name="lora_expand",
|
||||
op_func=_lora_expand,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_v1_expand_fake,
|
||||
fake_impl=_lora_expand_fake,
|
||||
)
|
||||
v1_expand = torch.ops.vllm.v1_expand
|
||||
lora_expand = torch.ops.vllm.lora_expand
|
||||
|
||||
except AttributeError:
|
||||
v1_expand = _v1_expand
|
||||
lora_expand = _lora_expand
|
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
V1 LoRA kernels metadata preparation utilities.
|
||||
LoRA kernels metadata preparation utilities.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
@ -10,7 +10,7 @@ import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class V1KernelMeta:
|
||||
class LoRAKernelMeta:
|
||||
token_lora_mapping: torch.Tensor
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor
|
||||
active_lora_ids: torch.Tensor
|
||||
@ -19,7 +19,7 @@ class V1KernelMeta:
|
||||
|
||||
@staticmethod
|
||||
def make(max_loras: int, max_num_tokens: int,
|
||||
device: Union[torch.device, str]) -> "V1KernelMeta":
|
||||
device: Union[torch.device, str]) -> "LoRAKernelMeta":
|
||||
|
||||
token_lora_mapping = torch.empty(max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
@ -47,7 +47,7 @@ class V1KernelMeta:
|
||||
lora_token_start_loc = torch.zeros(max_loras + 2,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
return V1KernelMeta(
|
||||
return LoRAKernelMeta(
|
||||
token_lora_mapping=token_lora_mapping,
|
||||
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
|
||||
active_lora_ids=active_lora_ids,
|
||||
@ -105,7 +105,7 @@ class V1KernelMeta:
|
||||
This function returns the kernel metadata required for the current
|
||||
forward pass execution of the kernel. The function returns all the
|
||||
metadata required by the kernel, in order, as a tuple, so it can be
|
||||
unpacked directly during the v1_shrink/v1_expand function call.
|
||||
unpacked directly during the lora_shrink/lora_expand function call.
|
||||
|
||||
Args:
|
||||
token_nums (int): Number of input tokens in the current forward
|
@ -18,15 +18,15 @@ from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
|
||||
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
|
||||
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
|
||||
lora_token_start_loc, lora_ids, scaling, input_d0_stride,
|
||||
input_d1_stride, lora_d0_stride, lora_d1_stride,
|
||||
lora_d2_stride, output_d0_stride, output_d1_stride,
|
||||
output_d2_stride, BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr):
|
||||
lora_token_start_loc, lora_ids, scaling,
|
||||
input_d0_stride, input_d1_stride, lora_d0_stride,
|
||||
lora_d1_stride, lora_d2_stride, output_d0_stride,
|
||||
output_d1_stride, output_d2_stride,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
|
||||
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||
@ -96,7 +96,7 @@ def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _v1_shrink(
|
||||
def _lora_shrink(
|
||||
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
|
||||
lora_a_weights: List[
|
||||
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
|
||||
@ -174,7 +174,7 @@ def _v1_shrink(
|
||||
MAX_LORAS,
|
||||
)
|
||||
|
||||
_v1_shrink_kernel[grid](
|
||||
_lora_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
@ -209,7 +209,7 @@ def _v1_shrink(
|
||||
return
|
||||
|
||||
|
||||
def _v1_shrink_fake(
|
||||
def _lora_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
@ -225,12 +225,12 @@ def _v1_shrink_fake(
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="v1_shrink",
|
||||
op_func=_v1_shrink,
|
||||
op_name="lora_shrink",
|
||||
op_func=_lora_shrink,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_v1_shrink_fake,
|
||||
fake_impl=_lora_shrink_fake,
|
||||
)
|
||||
v1_shrink = torch.ops.vllm.v1_shrink
|
||||
lora_shrink = torch.ops.vllm.lora_shrink
|
||||
|
||||
except AttributeError:
|
||||
v1_shrink = _v1_shrink
|
||||
lora_shrink = _lora_shrink
|
@ -1,249 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .kernel_utils import do_expand_kernel
|
||||
from .utils import _get_lora_b_ptr
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
slice_start_loc,
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride, # 1
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr, # 1
|
||||
output_d0_stride,
|
||||
output_d1_stride, # 1
|
||||
output_hs_ptr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
SAME_STRIDE: tl.constexpr):
|
||||
"""
|
||||
|
||||
Similar to the 'sgmv_expand' operator, but with an added parameter
|
||||
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
|
||||
might be that in the future, we could implement a fusion operator to
|
||||
achieve the current functionality instead of having to call it multiple
|
||||
times.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
slice_id = tl.program_id(axis=2)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
# When the output dimensions of each slice are the same,cur_n=N, otherwise
|
||||
# cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
|
||||
# qkv linear.
|
||||
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M >= M:
|
||||
return
|
||||
if pid_n * BLOCK_N >= curr_N:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
|
||||
m_offset = tl.load(b_seq_start_loc + cur_batch)
|
||||
|
||||
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
|
||||
cta_m_offset = m_offset + (pid_m * BLOCK_M)
|
||||
offset_m = tl.arange(0, BLOCK_M)
|
||||
ram = cta_m_offset + tl.max_contiguous(
|
||||
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
|
||||
do_expand_kernel(
|
||||
pid_n,
|
||||
lora_index,
|
||||
slice_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
curr_N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
slice_start_loc,
|
||||
# input ptr strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride,
|
||||
# lora ptr strides
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr,
|
||||
# out ptr strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
# constants
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
SAME_STRIDE,
|
||||
SLICE_NUM,
|
||||
EVEN_K,
|
||||
CAST_TYPE,
|
||||
ADD_INPUTS,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _sgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (List[torch.Tensor]): lora'b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
offset_start (int, optional): Offset start for output_tensor.
|
||||
Defaults to 0.
|
||||
add_inputs (bool, optional): Whether to add the input tensor to the
|
||||
output tensor. Defaults to False.
|
||||
"""
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
for weight in lora_b_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(1) == token_nums
|
||||
assert inputs.size(0) == len(lora_b_weights)
|
||||
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert output_tensor.is_contiguous()
|
||||
(slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
|
||||
same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
|
||||
b_seq_start_loc.device)
|
||||
|
||||
# TODO tuning this config
|
||||
K = lora_b_weights[0].shape[-1] # K= rank
|
||||
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 16
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
|
||||
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
|
||||
batches,
|
||||
len(lora_b_weights),
|
||||
)
|
||||
_sgmv_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
MAX_N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
slice_start_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
inputs.stride(2),
|
||||
lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor,
|
||||
lora_strides_d2_tensor,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
hidden_sizes_tensor,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
len(lora_b_weights),
|
||||
same_stride,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _sgmv_expand_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="sgmv_expand",
|
||||
op_func=_sgmv_expand,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_sgmv_expand_fake,
|
||||
)
|
||||
sgmv_expand = torch.ops.vllm.sgmv_expand
|
||||
|
||||
except AttributeError:
|
||||
sgmv_expand = _sgmv_expand
|
@ -1,224 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .kernel_utils import do_shrink_kernel
|
||||
from .utils import _get_lora_a_ptr
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sgmv_shrink_kernel(
|
||||
input_ptr,
|
||||
lora_ptr, #1-3
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
scaling,
|
||||
input_d0_stride,
|
||||
input_d1_stride, # 1
|
||||
lora_d0_stride,
|
||||
lora_d1_stride,
|
||||
lora_d2_stride, # 1
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
output_d2_stride, # 1
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr):
|
||||
"""
|
||||
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
|
||||
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
|
||||
introducing SPLIT-K can improve performance
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
pid_mix = tl.program_id(axis=1)
|
||||
cur_batch = tl.program_id(axis=2)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
if SLICE_NUM == 1:
|
||||
slice_id: tl.constexpr = 0
|
||||
pid_sk = tl.program_id(axis=1)
|
||||
else:
|
||||
pid_mix = tl.program_id(axis=1)
|
||||
slice_id = pid_mix // SPLIT_K
|
||||
pid_sk = pid_mix % SPLIT_K
|
||||
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M >= M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
|
||||
m_offset = tl.load(b_seq_start_loc + cur_batch)
|
||||
|
||||
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
|
||||
cta_m_offset = m_offset + (pid_m * BLOCK_M)
|
||||
offset_m = tl.arange(0, BLOCK_M)
|
||||
ram = cta_m_offset + tl.max_contiguous(
|
||||
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
|
||||
|
||||
do_shrink_kernel(
|
||||
pid_n,
|
||||
pid_sk,
|
||||
slice_id,
|
||||
lora_index,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram,
|
||||
# input strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
# lora strides
|
||||
lora_d0_stride,
|
||||
lora_d1_stride,
|
||||
lora_d2_stride,
|
||||
# output strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
output_d2_stride,
|
||||
scaling,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
SLICE_NUM)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_a_weights (List[torch.Tensor]): lora'a weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights[0].dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
for weight in lora_a_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(0) == token_nums
|
||||
assert inputs.size(1) == lora_a_weights[0].size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
|
||||
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device)
|
||||
# TODO tuning this config
|
||||
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 16
|
||||
BLOCK_K = 32
|
||||
SPLIT_K = 8
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
SPLIT_K * len(lora_a_weights),
|
||||
batches,
|
||||
)
|
||||
_sgmv_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_strides_d0,
|
||||
lora_strides_d1,
|
||||
lora_strides_d2,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
output_tensor.stride(2),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
len(lora_a_weights),
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def sgmv_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="sgmv_shrink",
|
||||
op_func=_sgmv_shrink,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=sgmv_shrink_fake,
|
||||
)
|
||||
sgmv_shrink = torch.ops.vllm.sgmv_shrink
|
||||
|
||||
except AttributeError:
|
||||
sgmv_shrink = _sgmv_shrink
|
@ -1,55 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import functools
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _get_op_configs(op_type: str, batch: int, hidden_size: int):
|
||||
# TODO: add optimal configurations
|
||||
return None
|
||||
|
||||
|
||||
def _check_divisibility(hidden_size: int):
|
||||
# The bgmv_expand kernel requires that the hidden_size be divisible by
|
||||
# the number below.
|
||||
divisibility = [2, 4, 8, 16, 32, 64]
|
||||
divisibility.sort(reverse=True)
|
||||
for div in divisibility:
|
||||
if hidden_size % div == 0:
|
||||
return div
|
||||
# hidden_size is an odd number
|
||||
return 1
|
||||
|
||||
|
||||
def _get_default_config(op_type: str, batch: int, hidden_size: int):
|
||||
if op_type == "expand":
|
||||
return {
|
||||
"BLOCK_N": 256,
|
||||
"SPLIT_N": _check_divisibility(hidden_size),
|
||||
"num_warps": 8
|
||||
}
|
||||
else:
|
||||
return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8}
|
||||
|
||||
|
||||
def get_lora_op_configs(op_type: str, batch: int,
|
||||
hidden_size: int) -> Dict[str, int]:
|
||||
"""Inspired by `fused_moe_kernel`
|
||||
The return value will be a dictionary mapping an irregular grid of batch
|
||||
sizes and hidden_size to configurations of the bgmv-related kernel.
|
||||
NOTE: It currently only supports the default configuration. We plan to
|
||||
generate optimal configurations for different hardware in the future using
|
||||
scripts similar to `benchmark_moe.py`.
|
||||
"""
|
||||
config = _get_op_configs(op_type, batch, hidden_size)
|
||||
if not config:
|
||||
config = _get_default_config(op_type, batch, hidden_size)
|
||||
return config
|
||||
|
||||
|
||||
_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
|
||||
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand
|
||||
from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta
|
||||
from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink
|
||||
|
||||
__all__ = [
|
||||
"v1_expand",
|
||||
"v1_shrink",
|
||||
"V1KernelMeta",
|
||||
]
|
@ -10,20 +10,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as env
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
if env.VLLM_USE_V1:
|
||||
from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand,
|
||||
v1_shrink)
|
||||
else:
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops import sgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
|
||||
lora_shrink)
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
@ -32,57 +24,8 @@ if TYPE_CHECKING:
|
||||
from vllm.lora.models import LongContextLoRAContext
|
||||
|
||||
|
||||
class V1KernelMixin:
|
||||
|
||||
def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int,
|
||||
max_batches: int, device: Union[torch.device, str]):
|
||||
self.token_mapping_v1_meta = V1KernelMeta.make(max_loras,
|
||||
max_num_batched_tokens,
|
||||
device=device)
|
||||
self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras,
|
||||
max_batches,
|
||||
device=device)
|
||||
|
||||
def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor):
|
||||
self.token_mapping_v1_meta.prepare_tensors(token_lora_indices)
|
||||
self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices)
|
||||
|
||||
def _v1_apply_shrink(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: Tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
):
|
||||
v1_shrink(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.token_mapping_v1_meta.meta_args(x.size(0)),
|
||||
scale,
|
||||
)
|
||||
|
||||
def _v1_apply_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: Tuple[torch.Tensor, ...],
|
||||
offset_start: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
v1_expand(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.token_mapping_v1_meta.meta_args(x.size(0)),
|
||||
offset_start=offset_start,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperGPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
@ -96,9 +39,12 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
|
||||
self.max_loras = kwargs['max_loras']
|
||||
|
||||
if env.VLLM_USE_V1:
|
||||
self._v1_make_metadata(self.max_loras, max_num_batched_tokens,
|
||||
max_batches, device)
|
||||
self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
|
||||
max_num_batched_tokens,
|
||||
device=device)
|
||||
self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
|
||||
max_batches,
|
||||
device=device)
|
||||
|
||||
def update_metadata(
|
||||
self,
|
||||
@ -110,83 +56,18 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
long_lora_context: Optional["LongContextLoRAContext"] = None,
|
||||
**kwargs):
|
||||
|
||||
if env.VLLM_USE_V1:
|
||||
self.is_prefill = mapping.is_prefill
|
||||
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
|
||||
vocab_size, extra_vocab_size,
|
||||
long_lora_context)
|
||||
self._v1_prepare_metadata_tensors(self.token_lora_indices,
|
||||
self.sampler_indices)
|
||||
else:
|
||||
# Forward to base class update_metadata
|
||||
PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id,
|
||||
max_loras, vocab_size,
|
||||
extra_vocab_size,
|
||||
long_lora_context, **kwargs)
|
||||
|
||||
def _apply_shrink_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: Tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_shrink(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
scale,
|
||||
)
|
||||
# Prepare cuda kernel metadata tensors
|
||||
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
|
||||
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
|
||||
|
||||
def _apply_shrink_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
|
||||
|
||||
def _apply_expand_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: Tuple[torch.Tensor, ...],
|
||||
offset_start: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
|
||||
sgmv_expand(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
offset_start=offset_start,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
def _apply_expand_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: Optional[int],
|
||||
y_slice_size: Optional[int],
|
||||
add_inputs: bool,
|
||||
):
|
||||
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
|
||||
y_slice_size, add_inputs)
|
||||
|
||||
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||
scale: float, **kwargs):
|
||||
def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor,
|
||||
...], scale: float, **kwargs):
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
@ -199,33 +80,24 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
y (torch.Tensor): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
if env.VLLM_USE_V1:
|
||||
self._v1_apply_shrink(y, x, lora_a_stacked, scale) # type: ignore
|
||||
else:
|
||||
if self.is_prefill:
|
||||
# NOTE fused kernel
|
||||
self._apply_shrink_prefill(
|
||||
y, # type: ignore
|
||||
lora_shrink(
|
||||
x,
|
||||
lora_a_stacked,
|
||||
scale)
|
||||
else:
|
||||
# TODO fuse these kernels
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
self._apply_shrink_decode(y[slice_idx], x,
|
||||
lora_a_stacked[slice_idx], scale)
|
||||
y,
|
||||
*self.token_mapping_meta.meta_args(x.size(0)),
|
||||
scale,
|
||||
)
|
||||
|
||||
def add_expand(self,
|
||||
y: torch.Tensor,
|
||||
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||
output_slices: Tuple[int, ...],
|
||||
@ -244,7 +116,7 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
x (torch.Tensor): Input tensors
|
||||
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||
bias's weight
|
||||
@ -259,37 +131,19 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
self._apply_bias(token_lora_indices, y, output_slices,
|
||||
lora_bias_stacked)
|
||||
|
||||
if env.VLLM_USE_V1:
|
||||
# TODO (varun): Profile with add_inputs = False. i.e. move the
|
||||
# addition out of the kernel
|
||||
self._v1_apply_expand(
|
||||
y,
|
||||
x, # type: ignore
|
||||
lora_b_stacked,
|
||||
offset_start,
|
||||
add_inputs=True)
|
||||
else:
|
||||
assert x.ndim == 3
|
||||
assert x.size(0) == len(output_slices)
|
||||
num_tokens = x.size(1) # first dimension is the num slices
|
||||
|
||||
if self.is_prefill:
|
||||
# NOTE fused kernel
|
||||
self._apply_expand_prefill(
|
||||
y,
|
||||
x, # type: ignore
|
||||
lora_expand(
|
||||
x,
|
||||
lora_b_stacked,
|
||||
offset_start,
|
||||
add_inputs=True)
|
||||
else:
|
||||
# TODO fuse these kernels
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
self._apply_expand_decode(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_start,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs,
|
||||
*self.token_mapping_meta.meta_args(num_tokens),
|
||||
offset_start=offset_start,
|
||||
add_inputs=True,
|
||||
)
|
||||
offset_start += output_slices[slice_idx]
|
||||
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(self,
|
||||
@ -311,24 +165,14 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
|
||||
if env.VLLM_USE_V1:
|
||||
self._v1_apply_expand(y,
|
||||
x.unsqueeze(dim=0), (lora_b_stacked, ),
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs)
|
||||
else:
|
||||
if self.is_prefill:
|
||||
sgmv_expand(
|
||||
lora_expand(
|
||||
x.unsqueeze(dim=0),
|
||||
(lora_b_stacked, ),
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
*self.token_mapping_meta.meta_args(x.size(0)),
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
else:
|
||||
bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
|
||||
add_inputs)
|
||||
|
||||
def add_lora_linear(self,
|
||||
y: torch.Tensor,
|
||||
@ -339,7 +183,7 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
scale: float,
|
||||
output_slices: Tuple[int, ...],
|
||||
*,
|
||||
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||
buffer: Optional[torch.Tensor] = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
@ -361,7 +205,7 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (Tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
|
||||
buffer (Optional[torch.Tensor]): Defaults to None.
|
||||
"""
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
@ -431,21 +275,11 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
|
||||
if env.VLLM_USE_V1:
|
||||
v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
|
||||
*self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale)
|
||||
lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
|
||||
*self.prompt_mapping_meta.meta_args(x.size(0)), scale)
|
||||
|
||||
v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
|
||||
lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
|
||||
y,
|
||||
*self.prompt_mapping_v1_meta.meta_args(buffer.size(0)),
|
||||
add_inputs=True)
|
||||
else:
|
||||
|
||||
# V0 LogitsProcessorWithLoRA always using bgmv.
|
||||
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
|
||||
bgmv_expand(buffer,
|
||||
lora_b_stacked,
|
||||
y,
|
||||
self.sampler_indices,
|
||||
*self.prompt_mapping_meta.meta_args(buffer.size(0)),
|
||||
add_inputs=True)
|
||||
y = y.view_as(y_org)
|
||||
|
@ -245,7 +245,6 @@ class MambaMixer2(CustomOp):
|
||||
assert num_heads % self.tp_size == 0, \
|
||||
"Tensor parallel world size must divide num heads."
|
||||
|
||||
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
|
||||
(
|
||||
"If tensor parallel world size does not divide num_heads, "
|
||||
|
@ -10,7 +10,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
@ -140,6 +139,10 @@ def _fused_moe_gguf(
|
||||
qweight_type2: int,
|
||||
act,
|
||||
) -> torch.Tensor:
|
||||
# lazy import to avoid triggering triton import in CPU backend
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
moe_align_block_size)
|
||||
|
||||
out_hidden_states = torch.empty_like(x)
|
||||
if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
|
||||
num_tokens, _ = x.shape
|
||||
|
@ -38,8 +38,6 @@ from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class BambaMLP(nn.Module):
|
||||
|
||||
|
@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .blip import BlipVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# We use this internally as placeholders since there is no image token
|
||||
@ -565,12 +565,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
|
||||
return Blip2ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@ -578,12 +577,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
image_embeds = flatten_bn(image_embeds, concat=True)
|
||||
|
||||
return Blip2ImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
|
@ -39,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
from .utils import (flatten_bn, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
@ -972,12 +972,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
|
||||
return ChameleonImagePixelInputs(
|
||||
type="pixel_values",
|
||||
|
@ -478,7 +478,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
flatten_bn(images_spatial_crop, concat=True)))
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
|
@ -578,7 +578,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
|
@ -838,7 +838,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
@ -856,7 +856,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
assert isinstance(image_num_patches, (torch.Tensor, list))
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
|
@ -36,8 +36,6 @@ from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class JambaMoE(nn.Module):
|
||||
|
||||
|
@ -349,21 +349,18 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
List[b, Tensor(nb_frames, nb_channels, height, width)]
|
||||
}
|
||||
"""
|
||||
pixel_values = kwargs.pop("pixel_values_videos", None)
|
||||
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
|
||||
|
||||
if pixel_values is None:
|
||||
if pixel_values_videos is None:
|
||||
return None
|
||||
|
||||
if not (is_list_of(pixel_values,
|
||||
(torch.Tensor)) # different shape videos
|
||||
or isinstance(pixel_values,
|
||||
torch.Tensor)): # same shape videos
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel_values_videos. "
|
||||
f"Got type: {type(pixel_values_videos)}")
|
||||
|
||||
return LlavaNextVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
data=pixel_values,
|
||||
data=pixel_values_videos,
|
||||
)
|
||||
|
||||
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||
|
@ -574,10 +574,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if pixel_values_videos is None:
|
||||
return None
|
||||
|
||||
if not (is_list_of(pixel_values_videos,
|
||||
torch.Tensor) # different shape videos
|
||||
or isinstance(pixel_values_videos,
|
||||
torch.Tensor)): # same shape videos
|
||||
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel_values_videos. "
|
||||
f"Got type: {type(pixel_values_videos)}")
|
||||
|
||||
|
@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -283,17 +283,19 @@ class Olmo2Model(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
"""
|
||||
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
||||
"""
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
# Get embeddings of input.
|
||||
# shape: (batch_size, seq_len, d_model)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
@ -337,7 +339,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.sampler = get_sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -346,11 +348,13 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
@ -23,7 +23,7 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
@ -270,12 +270,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
|
||||
return PaliGemmaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@ -287,8 +286,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
image_embeds = flatten_bn(image_embeds, concat=True)
|
||||
|
||||
return PaliGemmaImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
|
@ -711,7 +711,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
@ -722,13 +722,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return QwenImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
data=flatten_bn(image_embeds, concat=True),
|
||||
)
|
||||
|
||||
return None
|
||||
|
@ -105,6 +105,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
|
||||
# [Encoder-decoder]
|
||||
"BartModel": ("bart", "BartForConditionalGeneration"),
|
||||
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
|
||||
|
1031
vllm/model_executor/models/zamba2.py
Normal file
1031
vllm/model_executor/models/zamba2.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -62,9 +62,10 @@ class LoRAModelRunnerMixin:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
|
||||
# Set is_prefill to True, so we always use the SGMV kernels.
|
||||
# For cuda platforms, we have specialized triton kernels, and
|
||||
# the cuda path ignores `is_prefill`.
|
||||
# Set is_prefill to True, so we always use the SGMV kernels on
|
||||
# non-cuda platforms.
|
||||
# On cuda platforms we use the same kernels for prefill and
|
||||
# decode and this flag is generally ignored.
|
||||
lora_mapping = LoRAMapping(token_lora_mapping,
|
||||
prompt_lora_mapping,
|
||||
is_prefill=True)
|
||||
|
Reference in New Issue
Block a user