[Kernel] Register punica ops directly (#10522)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2024-11-22 01:18:11 +08:00
committed by GitHub
parent da7e702c6f
commit 2385b60d83
7 changed files with 157 additions and 24 deletions

View File

@ -6,12 +6,13 @@ maximum ranks.
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
# Enable custom op register
import vllm.lora.ops.bgmv_expand
import vllm.lora.ops.bgmv_expand_slice
import vllm.lora.ops.bgmv_shrink
import vllm.lora.ops.sgmv_expand
import vllm.lora.ops.sgmv_expand_slice
import vllm.lora.ops.sgmv_shrink # noqa: F401
from vllm.platforms import current_platform
from .utils import (generate_data, generate_data_for_expand_nslices,
@ -37,6 +38,16 @@ def assert_close(a, b):
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
# Unlike test_punica_sizes.py, we directly utilize custom op for
# testing, which verifies the correct registration of these ops.
bgmv_expand = torch.ops.vllm.bgmv_expand
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
bgmv_shrink = torch.ops.vllm.bgmv_shrink
sgmv_expand = torch.ops.vllm.sgmv_expand
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
sgmv_shrink = torch.ops.vllm.sgmv_shrink
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)

View File

@ -9,6 +9,8 @@ import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@ -162,9 +164,24 @@ def _bgmv_expand(
return
def bgmv_expand_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
return
try:
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
direct_register_custom_op(
op_name="bgmv_expand",
op_func=_bgmv_expand,
mutates_args=["output_tensor"],
fake_impl=bgmv_expand_fake,
)
bgmv_expand = torch.ops.vllm.bgmv_expand
except AttributeError:
bgmv_expand = _bgmv_expand

View File

@ -9,6 +9,8 @@ import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@ -179,9 +181,26 @@ def _bgmv_expand_slice(
return
def bgmv_expand_slice_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> None:
return
try:
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
direct_register_custom_op(
op_name="bgmv_expand_slice",
op_func=_bgmv_expand_slice,
mutates_args=["output_tensor"],
fake_impl=bgmv_expand_slice_fake,
)
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice

View File

@ -9,6 +9,8 @@ import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
from .utils import get_lora_op_configs
@ -142,9 +144,24 @@ def _bgmv_shrink(
return
def bgmv_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
return
try:
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
direct_register_custom_op(
op_name="bgmv_shrink",
op_func=_bgmv_shrink,
mutates_args=["output_tensor"],
fake_impl=bgmv_shrink_fake,
)
bgmv_shrink = torch.ops.vllm.bgmv_shrink
except AttributeError:
bgmv_shrink = _bgmv_shrink

View File

@ -9,6 +9,8 @@ import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
@triton.jit
def _sgmv_expand_kernel(
@ -196,9 +198,30 @@ def _sgmv_expand(
return
def sgmv_expand_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> None:
return
try:
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
direct_register_custom_op(
op_name="sgmv_expand",
op_func=_sgmv_expand,
mutates_args=["output_tensor"],
fake_impl=sgmv_expand_fake,
)
sgmv_expand = torch.ops.vllm.sgmv_expand
except AttributeError:
sgmv_expand = _sgmv_expand

View File

@ -9,6 +9,8 @@ import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
@triton.jit
def _sgmv_expand_slice_kernel(
@ -209,9 +211,31 @@ def _sgmv_expand_slice(
return
def sgmv_expand_slice_fake(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> None:
return
try:
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
direct_register_custom_op(
op_name="sgmv_expand_slice",
op_func=_sgmv_expand_slice,
mutates_args=["output_tensor"],
fake_impl=sgmv_expand_slice_fake,
)
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice

View File

@ -9,6 +9,8 @@ import torch
import triton
import triton.language as tl
from vllm.utils import direct_register_custom_op
@triton.jit
def _sgmv_shrink_kernel(
@ -190,9 +192,29 @@ def _sgmv_shrink(
return
def sgmv_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
return
try:
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
direct_register_custom_op(
op_name="sgmv_shrink",
op_func=_sgmv_shrink,
mutates_args=["output_tensor"],
fake_impl=sgmv_shrink_fake,
)
sgmv_shrink = torch.ops.vllm.sgmv_shrink
except AttributeError:
sgmv_shrink = _sgmv_shrink