[xpu]feat: support multi-lora on xpu (#20616)
Signed-off-by: yan <yan.ma@intel.com>
@@ -13,6 +13,7 @@ import triton.language as tl
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
@@ -283,6 +284,7 @@ try:
         op_func=_lora_expand,
         mutates_args=["output_tensor"],
         fake_impl=_lora_expand_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
     lora_expand = torch.ops.vllm.lora_expand
 
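The hunks above (and the matching shrink-op hunks below) are the whole of the kernel-side change: the Triton LoRA ops are registered through direct_register_custom_op with the dispatch key reported by the active platform instead of an implicit CUDA-only registration. A minimal sketch of that registration pattern, with a hypothetical op name and toy kernel standing in for the real Triton launchers (the keyword names follow the diff; everything else here is illustrative):

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def _toy_scale(x: torch.Tensor, output_tensor: torch.Tensor) -> None:
    # Real op would launch a Triton kernel; this just writes 2 * x in place.
    output_tensor.copy_(x * 2)


def _toy_scale_fake(x: torch.Tensor, output_tensor: torch.Tensor) -> None:
    # Fake impl used for tracing/compilation; it must not touch real data.
    return


direct_register_custom_op(
    op_name="toy_scale",                       # hypothetical op name
    op_func=_toy_scale,
    mutates_args=["output_tensor"],
    fake_impl=_toy_scale_fake,
    # "CUDA" on NVIDIA builds, "XPU" on Intel GPU builds.
    dispatch_key=current_platform.dispatch_key,
)
toy_scale = torch.ops.vllm.toy_scale
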
@@ -13,6 +13,7 @@ import triton.language as tl
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
@@ -237,6 +238,7 @@ try:
         op_func=_lora_shrink,
         mutates_args=["output_tensor"],
         fake_impl=_lora_shrink_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
     lora_shrink = torch.ops.vllm.lora_shrink
 
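The shrink op receives the identical one-line change. current_platform.dispatch_key is the PyTorch dispatch-key string each vLLM Platform subclass declares, which is what lets a single registration path serve both backends:

from vllm.platforms import current_platform

# Resolved from the active Platform subclass:
# "CUDA" on an NVIDIA build, "XPU" on an Intel GPU build.
print(current_platform.dispatch_key)
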
@@ -35,7 +35,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
         lora_strides_d1.append(lora_a_weight.stride(1))
         lora_strides_d2.append(lora_a_weight.stride(2))
     if len(lora_a_weights) > 1:
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
     else:
         lora_ptr_tensor = lora_a_weights[0]
 
@@ -89,8 +91,12 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
 
     if len(lora_weights) > 1:
         # note these are device tensors
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
-        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
+        slice_start_tensor = torch.tensor(slice_offset_lst,
+                                          device=device,
+                                          dtype=torch.uint64)
     else:
         slice_start_tensor = slice_offset_lst[0]
         lora_ptr_tensor = lora_b_weight[0]
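Both pointer helpers pack raw device addresses (Tensor.data_ptr() values) into a lookup tensor for the Triton kernels. torch.tensor() infers a signed int64 dtype for a list of Python ints, which presumably cannot hold device addresses that land in the upper half of the 64-bit range on XPU, so the dtype is pinned to torch.uint64. A standalone sketch of the same branching (names are illustrative, not the vLLM helpers themselves):

import torch


def pack_weight_ptrs(weights: list[torch.Tensor],
                     device: torch.device) -> torch.Tensor:
    # Illustrative only: mirrors the branching in _get_lora_a_ptr/_get_lora_b_ptr.
    if len(weights) > 1:
        ptrs = [w.data_ptr() for w in weights]
        # Explicit uint64 so addresses above 2**63 - 1 do not overflow the
        # int64 dtype that torch.tensor() would otherwise infer.
        return torch.tensor(ptrs, device=device, dtype=torch.uint64)
    # Single slice: the kernel takes the weight tensor directly, so no
    # pointer lookup tensor is needed.
    return weights[0]
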
@@ -27,6 +27,7 @@ from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, PlaceholderModule
 
 if TYPE_CHECKING:
@@ -513,7 +514,9 @@ def deserialize_tensorizer_model(model: nn.Module,
             **tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
                 stream,
                 dtype=tensorizer_config.dtype,
-                device=torch.device("cuda", torch.cuda.current_device()),
+                device=f'xpu:{torch.xpu.current_device()}'
+                if current_platform.is_xpu() else
+                f'cuda:{torch.cuda.current_device()}',
                 **tensorizer_args.deserialization_kwargs) as deserializer:
         deserializer.load_into_module(model)
     end = time.perf_counter()
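The tensorizer loader previously hard-coded a CUDA target for deserialization; it now selects the current XPU device when running on Intel GPUs. The same selection logic, pulled out into a small helper purely for readability (the helper itself is not part of the patch):

import torch

from vllm.platforms import current_platform


def deserialization_device() -> str:
    # Same expression as in the diff above, just easier to read on its own.
    if current_platform.is_xpu():
        return f"xpu:{torch.xpu.current_device()}"
    return f"cuda:{torch.cuda.current_device()}"
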
@@ -58,6 +58,10 @@ class XPUPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         device_props = torch.xpu.get_device_properties(device_id)
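get_punica_wrapper() returns a fully qualified class name rather than a class object, and XPU points at the same GPU Punica wrapper used on CUDA, since the Triton LoRA ops above now dispatch on either backend. A sketch of how such a qualified name is typically resolved (the exact import helper vLLM uses internally is assumed here):

import importlib


def resolve_by_qualname(qualname: str):
    # Split "pkg.module.Class" into module path and attribute name,
    # import the module lazily, and return the attribute.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# wrapper_cls = resolve_by_qualname(
#     "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU")
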
@@ -78,6 +82,13 @@ class XPUPlatform(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64
 
+        # FIXME: Temporarily forcing eager mode
+        # remove after t.compile support stabilizes.
+        if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
+                and not vllm_config.model_config.enforce_eager):
+            from vllm.config import CompilationLevel
+            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
+
         # Instances created using VllmConfig() typically have model_config as
         # None by default. The modification involves adding a check to prevent
         # potential null exceptions check and update model config.
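Taken together, the changes let the existing Triton multi-LoRA path run on Intel XPUs, with eager execution forced by the platform hook above until torch.compile support stabilizes. A minimal end-to-end usage sketch with a hypothetical base model and adapter path; on an XPU build, enforce_eager=True only makes explicit what check_and_update_config already enforces:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",   # hypothetical base model
    enable_lora=True,
    max_loras=2,                        # serve several adapters concurrently
    enforce_eager=True,                 # redundant on XPU; forced by the platform
)

outputs = llm.generate(
    "Give me a short haiku about GPUs.",
    SamplingParams(max_tokens=32),
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora/adapter"),
)
print(outputs[0].outputs[0].text)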