[XPU][CI] enhance xpu test support (#20652)

Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
Co-authored-by: zhenwei-intel <zhenweiliu@habana.ai>
Author: Liangliang Ma
Date: 2025-07-10 00:53:09 +08:00
Committed by: GitHub
Parent commit: eb58f5953d
Commit: a3e4e85ece
5 changed files with 18 additions and 12 deletions

View File

@@ -759,7 +759,8 @@ class VllmRunner:
    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `block_size`: To reduce memory usage, defaults to `64` on XPU devices
      and to `16` otherwise.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.
@@ -777,7 +778,7 @@ class VllmRunner:
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        block_size: int = 16 if not torch.xpu.is_available() else 64,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
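The signature change above makes the default block size device-dependent. As a reference point, a minimal standalone sketch of the same selection, assuming only a PyTorch build that exposes the XPU backend (the `hasattr` guard is an extra safety check, not part of the diff):

import torch

# Mirror the new default above: a 64-token paging block on XPU devices,
# 16 everywhere else.
has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
block_size = 64 if has_xpu else 16
print(f"default block_size: {block_size}")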

View File

@@ -53,3 +53,6 @@ class XpuCommunicator(DeviceCommunicatorBase):
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
        dist.broadcast(input_, src=src, group=self.device_group)
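For context, `torch.distributed.broadcast` is an in-place collective: every rank in the group passes a tensor of the same shape and dtype, and ranks other than `src` have theirs overwritten with the source rank's data. A small usage sketch with illustrative names (not part of the diff), assuming the process group is already initialized:

import torch
import torch.distributed as dist

def broadcast_from_source(t: torch.Tensor, src: int = 0, group=None) -> torch.Tensor:
    # Must be called by every rank in `group`; `t` is modified in place on
    # all ranks except `src`, then returned for convenience.
    dist.broadcast(t, src=src, group=group)
    return t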

View File

@@ -240,6 +240,8 @@ class GroupCoordinator:
        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(
                f"{current_platform.device_name}:{local_rank}")
@@ -1317,13 +1319,13 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
        Returns True if distributed is not initialized (single process).
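A minimal sketch of the behavior this docstring describes, written against plain `torch.distributed` rather than vLLM's internal parallel state, so it is an approximation rather than the function itself:

import torch.distributed as dist

def is_global_first_rank_sketch() -> bool:
    # Single-process runs (distributed not initialized) count as the first
    # rank, matching the "Returns True if distributed is not initialized" note.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    # Otherwise compare the global rank, which spans all parallelism groups.
    return dist.get_rank() == 0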
@@ -1352,7 +1354,7 @@ def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """

View File

@@ -91,6 +91,7 @@ class XPUPlatform(Platform):
        # FIXME: Temporarily forcing eager mode
        # remove after t.compile support stabilizes.
        if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
                and not vllm_config.model_config.enforce_eager):
            from vllm.config import CompilationLevel
@@ -111,9 +112,6 @@ class XPUPlatform(Platform):
                "mode.")
            model_config.enforce_eager = True

        if vllm_config.device_config is not None:
            assert vllm_config.device_config.device_type == "xpu"

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
@@ -131,8 +129,10 @@ class XPUPlatform(Platform):
            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
            logger.warning(
                "Please use spawn as start method if you want to use mp.")
        elif parallel_config.distributed_executor_backend != "ray" and \
                parallel_config.distributed_executor_backend != "uni":
        elif (parallel_config.distributed_executor_backend != "ray"
              and parallel_config.distributed_executor_backend != "uni"
              and parallel_config.distributed_executor_backend
              != "external_launcher"):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",

View File

@@ -27,7 +27,7 @@ class XPUModelRunner(GPUModelRunner):
        self.cascade_attn_enabled = False

    def _init_device_properties(self) -> None:
        pass
        self.num_sms = None

    def _sync_device(self) -> None:
        torch.xpu.synchronize()
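For context, `torch.xpu.synchronize()` blocks the host until all work queued on the XPU device has finished, the same role `torch.cuda.synchronize()` plays for CUDA. A guarded sketch (the availability checks are an addition for portability, not part of the diff):

import torch

def sync_accelerator() -> None:
    # Wait for in-flight kernels before reading results or timings.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()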