[XPU] support data parallel for MoE models on XPU (#22887)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
```diff
@@ -7,8 +7,13 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+import vllm.envs as envs
+from vllm.logger import init_logger
+
 from .base_device_communicator import DeviceCommunicatorBase
 
+logger = init_logger(__name__)
+
 
 class XpuCommunicator(DeviceCommunicatorBase):
 
```
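For orientation, here is a minimal, self-contained sketch (not vLLM code; the class name is hypothetical) of the pattern these imports serve: a device communicator wraps a torch `ProcessGroup` and forwards collectives to `torch.distributed`, which is what the XPU communicator's all-reduce path does in the next hunk.

```python
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class TinyDeviceCommunicator:
    """Hypothetical, simplified stand-in for DeviceCommunicatorBase."""

    def __init__(self, device_group: ProcessGroup) -> None:
        self.device_group = device_group

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # In-place sum across all ranks in the group, matching the
        # XpuCommunicator.all_reduce shown in the next hunk.
        dist.all_reduce(input_, group=self.device_group)
        return input_
```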
```diff
@@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
                  device_group: Optional[ProcessGroup] = None,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
+        if self.use_all2all:
+            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+            if all2all_backend == "naive":
+                from .all2all import NaiveAll2AllManager
+                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+                logger.info("Using naive all2all manager.")
 
     def all_reduce(self, input_) -> torch.Tensor:
         dist.all_reduce(input_, group=self.device_group)
```
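The constructor now instantiates an all2all manager when data/expert parallelism requires token exchange (`self.use_all2all`), picking the backend from `VLLM_ALL2ALL_BACKEND`; only the "naive" backend is wired up for XPU here. Roughly, a naive all2all for data-parallel MoE gathers every DP rank's tokens before expert computation and reduces the partial expert outputs afterwards. The sketch below illustrates that idea with plain `torch.distributed` collectives; it is not the actual `NaiveAll2AllManager` implementation and assumes equal token counts per rank.

```python
import torch
import torch.distributed as dist


def naive_dispatch(hidden_states: torch.Tensor,
                   dp_group: dist.ProcessGroup) -> torch.Tensor:
    """Gather every DP rank's tokens so each rank sees the full batch."""
    world_size = dist.get_world_size(dp_group)
    gathered = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(gathered, hidden_states, group=dp_group)
    return torch.cat(gathered, dim=0)


def naive_combine(expert_output: torch.Tensor,
                  dp_group: dist.ProcessGroup) -> torch.Tensor:
    """Sum partial expert outputs across DP ranks, then keep the local slice."""
    dist.all_reduce(expert_output, group=dp_group)
    rank = dist.get_rank(dp_group)
    tokens_per_rank = expert_output.shape[0] // dist.get_world_size(dp_group)
    return expert_output[rank * tokens_per_rank:(rank + 1) * tokens_per_rank]
```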
```diff
@@ -655,6 +655,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         forward_native = forward_tpu
     elif current_platform.is_cpu():
         forward_native = forward_cpu
+    elif current_platform.is_xpu():
+        forward_native = forward_xpu
     else:
         forward_native = forward_cuda
 
```
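The fused-MoE change follows the existing platform-dispatch pattern: `forward_native` is bound to a platform-specific implementation once, when the class body executes, and this PR adds the XPU branch. A reduced sketch of that pattern (the class and platform stub are hypothetical, not the vLLM `UnquantizedFusedMoEMethod`):

```python
import torch


class _PlatformStub:
    """Hypothetical stand-in for vllm.platforms.current_platform."""

    def is_xpu(self) -> bool:
        return hasattr(torch, "xpu") and torch.xpu.is_available()


current_platform = _PlatformStub()


class TinyFusedMoEMethod:

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the default/CUDA kernel path

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the XPU kernel path

    # Evaluated once while the class body runs: forward_native becomes an
    # alias of the platform-appropriate method, mirroring the hunk above.
    if current_platform.is_xpu():
        forward_native = forward_xpu
    else:
        forward_native = forward_cuda
```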