[XPU] Add xpu torch.compile support (#22609)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Author: Kunshang Ji
Date: 2025-08-27 13:33:27 +08:00
Committed by: GitHub
Parent: d272415e57
Commit: fce10dbed5

8 changed files with 36 additions and 11 deletions


@@ -31,6 +31,7 @@ docker run \
set -e
echo $ZE_AFFINITY_MASK
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
cd tests
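
For reference, a rough Python-API sketch of what the new `-O3 -O.cudagraph_mode=NONE` test line exercises; the keyword spellings (`compilation_config`, `level`, `cudagraph_mode`) are assumptions against this commit's tree, so treat it as illustrative rather than the canonical invocation.

```python
# Hedged sketch: approximate Python-API equivalent of the new CI command.
# Field names below are assumptions; verify against your vLLM version.
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode

llm = LLM(
    model="facebook/opt-125m",
    block_size=64,
    compilation_config=CompilationConfig(
        level=3,                            # CLI: -O3
        cudagraph_mode=CUDAGraphMode.NONE,  # CLI: -O.cudagraph_mode=NONE
    ),
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```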


@@ -190,8 +190,7 @@ class Attention(nn.Module, AttentionLayerBase):
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config

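The new `opaque_attention_op()` hook replaces the hard-coded `is_cuda_alike()`/`is_cpu()` check. A minimal, self-contained sketch of the pattern (toy classes, not vLLM's real ones): any platform that registers attention as a single opaque custom op returns True, so `use_direct_call` switches off and torch.compile sees attention as one graph node.

```python
# Toy illustration of the opaque_attention_op() platform hook (not vLLM code).
class Platform:
    @classmethod
    def opaque_attention_op(cls) -> bool:
        # Default: attention is not registered as one opaque custom op.
        return False


class XPUPlatform(Platform):
    @classmethod
    def opaque_attention_op(cls) -> bool:
        # XPU now opts in, matching CUDA, ROCm, and CPU.
        return True


current_platform = XPUPlatform()


class ToyAttention:
    def __init__(self) -> None:
        # Mirrors the new line in Attention.__init__.
        self.use_direct_call = not current_platform.opaque_attention_op()


print(ToyAttention().use_direct_call)  # False: dispatch via the custom op
```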

@@ -9,6 +9,7 @@ import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet.
# Will enable this when we switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug("XPU platform does not support fix functionalization"
"pass currently.")
return
self.begin()
self.dump_graph(graph, "before_fix_functionalization")

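As background, FixFunctionalizationPass rewrites `auto_functionalized` higher-order ops back into their in-place custom-op form to avoid extra copies; the new guard simply skips that rewrite on XPU until the vllm-xpu-kernels switch. A hedged, standalone sketch of the guard (stand-in class, not the real `VllmInductorPass` API):

```python
# Stand-in sketch of the XPU early-return guard; real pass logic omitted.
import logging

import torch.fx

logger = logging.getLogger(__name__)


class ToyFixFunctionalizationPass:
    def __init__(self, is_xpu: bool) -> None:
        self.is_xpu = is_xpu

    def __call__(self, graph: torch.fx.Graph) -> None:
        if self.is_xpu:
            # Same behavior as the diff: leave the FX graph untouched.
            logger.debug("XPU platform does not support the fix "
                         "functionalization pass currently.")
            return
        # ... otherwise, rewrite auto_functionalized nodes in `graph` here ...
```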

@@ -335,3 +335,7 @@ class CpuPlatform(Platform):
return (cls.supports_v1(model_config)
and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC,
CpuArchEnum.ARM, CpuArchEnum.S390X))
@classmethod
def opaque_attention_op(cls) -> bool:
return True


@@ -442,6 +442,10 @@ class CudaPlatformBase(Platform):
def use_custom_allreduce(cls) -> bool:
return True
@classmethod
def opaque_attention_op(cls) -> bool:
return True
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"


@@ -509,6 +509,14 @@ class Platform:
"""
return False
@classmethod
def opaque_attention_op(cls) -> bool:
"""
Returns True if we register attention as one giant opaque custom op
on the current platform.
"""
return False
@classmethod
def validate_request(
cls,


@@ -411,6 +411,10 @@ class RocmPlatform(Platform):
supported_archs = ['gfx94', 'gfx95']
return any(gfx in gcn_arch for gfx in supported_archs)
@classmethod
def opaque_attention_op(cls) -> bool:
return True
@classmethod
def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(


@@ -90,21 +90,14 @@ class XPUPlatform(Platform):
if cache_config and cache_config.block_size is None:
cache_config.block_size = 64
# FIXME: Temporarily forcing eager mode
# remove after t.compile support stabilizes.
if (envs.VLLM_USE_V1 and model_config is not None
and not vllm_config.model_config.enforce_eager):
from vllm.config import CompilationLevel
vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config
if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \
!= CUDAGraphMode.NONE:
logger.info("[XPU] CUDA graph is not supported on XPU, "
"disabling cudagraphs.")
logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
"cudagraphs. Fallback to cudagraph_mode=NONE")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# check and update parallel config
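
A simplified, hedged sketch of the cudagraph fallback above (toy enum, not the real `CUDAGraphMode`): any requested mode other than NONE, or an unset mode, is forced to NONE because CUDA graphs are unavailable on XPU.

```python
# Toy model of the XPU cudagraph fallback; enum values are illustrative only.
from enum import Enum


class ToyCUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def xpu_cudagraph_mode(requested: "ToyCUDAGraphMode | None") -> ToyCUDAGraphMode:
    if requested is None or requested != ToyCUDAGraphMode.NONE:
        print("[XPU] CUDA graph is not supported on XPU, disabling "
              "cudagraphs. Fallback to cudagraph_mode=NONE")
        return ToyCUDAGraphMode.NONE
    return requested


assert xpu_cudagraph_mode(ToyCUDAGraphMode.FULL) is ToyCUDAGraphMode.NONE
```
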
@@ -182,3 +175,7 @@ class XPUPlatform(Platform):
"Intel Arc A770 have bfloat16 accuracy known issue. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")
@classmethod
def opaque_attention_op(cls) -> bool:
return True