[xpu]feat: support multi-lora on xpu (#20616)

Signed-off-by: yan <yan.ma@intel.com>
Author: Yan Ma
Date: 2025-07-08 22:07:10 +08:00
Committed by: GitHub
Parent: b942c094e3
Commit: a4c23314c0

5 changed files with 28 additions and 4 deletions

vllm/lora/ops/triton_ops/lora_expand_op.py

@@ -13,6 +13,7 @@ import triton.language as tl
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -283,6 +284,7 @@ try:
         op_func=_lora_expand,
         mutates_args=["output_tensor"],
         fake_impl=_lora_expand_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
     lora_expand = torch.ops.vllm.lora_expand

vllm/lora/ops/triton_ops/lora_shrink_op.py

@@ -13,6 +13,7 @@ import triton.language as tl
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -237,6 +238,7 @@ try:
         op_func=_lora_shrink,
         mutates_args=["output_tensor"],
         fake_impl=_lora_shrink_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
     lora_shrink = torch.ops.vllm.lora_shrink
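Both diffs above make the same one-line change: direct_register_custom_op, vLLM's wrapper around torch.library registration, previously registered these ops under its default CUDA dispatch key; passing current_platform.dispatch_key registers them under XPU on Intel GPUs, so calls on XPU tensors dispatch to the Triton kernels. A minimal sketch of the registration pattern, using a hypothetical toy op (_toy_scale and toy_scale are illustrative names, not vLLM APIs):

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def _toy_scale(x: torch.Tensor, factor: float) -> None:
    # Real implementation; the LoRA ops launch Triton kernels here.
    x.mul_(factor)


def _toy_scale_fake(x: torch.Tensor, factor: float) -> None:
    # Metadata-only fake impl used during tracing; does no work.
    pass


direct_register_custom_op(
    op_name="toy_scale",
    op_func=_toy_scale,
    mutates_args=["x"],
    fake_impl=_toy_scale_fake,
    # "XPU" on Intel GPUs, "CUDA" elsewhere; omitting this kwarg
    # falls back to the helper's CUDA-only default.
    dispatch_key=current_platform.dispatch_key,
)

# The op is then callable as torch.ops.vllm.toy_scale(x, 2.0).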

vllm/lora/ops/triton_ops/utils.py

@@ -35,7 +35,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
         lora_strides_d1.append(lora_a_weight.stride(1))
         lora_strides_d2.append(lora_a_weight.stride(2))
     if len(lora_a_weights) > 1:
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
     else:
         lora_ptr_tensor = lora_a_weights[0]
@@ -89,8 +91,12 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
     if len(lora_weights) > 1:
         # note these are device tensors
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
-        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
+        slice_start_tensor = torch.tensor(slice_offset_lst,
+                                          device=device,
+                                          dtype=torch.uint64)
     else:
         slice_start_tensor = slice_offset_lst[0]
         lora_ptr_tensor = lora_b_weight[0]
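The explicit dtype matters because tensor_ptrs and slice_offset_lst hold raw device addresses obtained from Tensor.data_ptr(). With no dtype argument, torch.tensor() infers int64, which cannot represent addresses at or above 2**63, and such high addresses can occur for XPU allocations. A standalone sketch of the failure mode this guards against (not vLLM code; runs on CPU):

import torch

# Pack raw pointers for a kernel launch. data_ptr() returns a plain
# Python int holding a 64-bit address.
weights = [torch.randn(16, 16) for _ in range(3)]
ptrs = [w.data_ptr() for w in weights]

# torch.tensor(ptrs) with inferred int64 raises an overflow error if
# any address has the top bit set; the explicit unsigned dtype (as in
# the diff above) represents every 64-bit address exactly.
ptr_tensor = torch.tensor(ptrs, dtype=torch.uint64)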

vllm/model_executor/model_loader/tensorizer.py

@@ -27,6 +27,7 @@ from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, PlaceholderModule

 if TYPE_CHECKING:
@@ -513,7 +514,9 @@ def deserialize_tensorizer_model(model: nn.Module,
             **tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
                 stream,
                 dtype=tensorizer_config.dtype,
-                device=torch.device("cuda", torch.cuda.current_device()),
+                device=f'xpu:{torch.xpu.current_device()}'
+                if current_platform.is_xpu() else
+                f'cuda:{torch.cuda.current_device()}',
                 **tensorizer_args.deserialization_kwargs) as deserializer:
         deserializer.load_into_module(model)
     end = time.perf_counter()
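deserialize_tensorizer_model previously hard-coded a CUDA target for TensorDeserializer; the diff derives the device string from the active platform instead. The selection pattern in isolation, assuming a PyTorch build with XPU support:

import torch

from vllm.platforms import current_platform

# "xpu:N" on Intel GPUs, "cuda:N" elsewhere; both are valid torch
# device strings, so the deserializer call needs no other change.
device = (f'xpu:{torch.xpu.current_device()}'
          if current_platform.is_xpu() else
          f'cuda:{torch.cuda.current_device()}')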

vllm/platforms/xpu.py

@@ -58,6 +58,10 @@ class XPUPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)

+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         device_props = torch.xpu.get_device_properties(device_id)
@@ -78,6 +82,13 @@ class XPUPlatform(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64

+        # FIXME: Temporarily forcing eager mode
+        # remove after t.compile support stabilizes.
+        if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
+                and not vllm_config.model_config.enforce_eager):
+            from vllm.config import CompilationLevel
+            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
+
         # Instances created using VllmConfig() typically have model_config as
         # None by default. The modification involves adding a check to prevent
         # potential null exceptions check and update model config.
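get_punica_wrapper returns a dotted path rather than a class, so the platform module stays import-light and the wrapper is resolved lazily; XPU reuses the GPU Punica wrapper now that the Triton LoRA ops register under its dispatch key. A hypothetical standalone version of that resolution step (vLLM's own helper for this is resolve_obj_by_qualname in vllm.utils; resolve_qualname below is an illustrative stand-in):

from importlib import import_module


def resolve_qualname(qualname: str):
    # Split "pkg.module.Name" into module path and attribute name,
    # import the module, and return the attribute.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(import_module(module_name), attr)


# With vLLM installed, this yields the wrapper class advertised above.
wrapper_cls = resolve_qualname(
    "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU")

The second hunk is a stopgap rather than a feature: under V1, the XPU platform forces CompilationLevel.NO_COMPILATION whenever the model is not already running in eager mode, to be removed once torch.compile support on XPU stabilizes.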