[XPU] Support XCCL on deepspeed side (#7299)

XCCL will be used for XPU device on Pytorch-2.8, with this support will
remove torch-ccl on XPU device, and we will also reserve the old path
for torch-CCL enable.

---------

Signed-off-by: yisheng <yi.sheng@intel.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
This commit is contained in:
YiSheng5
2025-05-23 00:31:26 +08:00
committed by GitHub
parent 0e3209a16b
commit bdba8231bc
2 changed files with 46 additions and 12 deletions

View File

@ -136,6 +136,21 @@ def get_accelerator():
accelerator_name = "xpu"
except ImportError as e:
pass
if accelerator_name is None:
try:
import torch
# torch.xpu will be supported in upstream pytorch-2.8.
# Currently we can run on xpu device only using pytorch,
# also reserve the old path using ipex when the torch version is old.
if hasattr(torch, 'xpu'):
if torch.cuda.device_count() == 0: #ignore-cuda
if torch.xpu.device_count() > 0 and torch.xpu.is_available():
accelerator_name = "xpu"
else:
pass
except ImportError as e:
pass
if accelerator_name is None:
try:
import torch_npu # noqa: F401,F811 # type: ignore

View File

@ -5,19 +5,32 @@
import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
import functools
import importlib
import inspect
try:
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
oneccl_imported_p = True
except ImportError as e:
oneccl_imported_p = False
try:
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
ipex_imported_p = True
except ImportError as e:
ipex_imported_p = False
class XPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'xpu'
self._communication_backend_name = 'ccl'
if oneccl_imported_p:
self._communication_backend_name = 'ccl'
else:
# changed to xccl if not using torch-CCL on XPU device
self._communication_backend_name = 'xccl'
self._compile_backend = "inductor"
self.aligned_tensors = []
self.class_dict = None
@ -26,11 +39,14 @@ class XPU_Accelerator(DeepSpeedAccelerator):
return False
def use_host_timers(self):
# WA XPU event will be consolidated in 2.6
if ipex.__version__ < '2.6':
return True
else:
if not ipex_imported_p:
return self.is_synchronized_device()
else:
# WA XPU event will be consolidated in 2.6
if ipex.__version__ < '2.6':
return True
else:
return self.is_synchronized_device()
def resolves_data_dependency(self):
return self.is_synchronized_device()
@ -290,10 +306,13 @@ class XPU_Accelerator(DeepSpeedAccelerator):
return self.class_dict['NotImplementedBuilder']
def build_extension(self):
try:
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
except ImportError:
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
if ipex_imported_p:
try:
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
except ImportError:
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
else:
from torch.utils.cpp_extension import DpcppBuildExtension
return DpcppBuildExtension
def export_envs(self):