Mirror of https://github.com/deepspeedai/DeepSpeed.git
avoid setting device_id to init_process_group (#7542)
In some use cases, such as vLLM, we need to create new distributed groups not only on GPU but also on CPU. Setting `device_id` here prevents creating a new distributed group on CPU: [L230](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py#L230). This PR fixes that bug.

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
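The scenario the commit message describes can be illustrated with a minimal sketch (not taken from the PR): after the default process group is initialized, an additional CPU group is created with the `gloo` backend, similar to what vLLM does in `parallel_state.py`. Binding the default group to a CUDA device via `device_id` can interfere with creating such a CPU group, which is why the patch below only passes `device_id` for the `cuda` accelerator. The helper name `init_groups`, the env-var handling, and the `nccl`/`gloo` backend choices are illustrative assumptions, not part of the DeepSpeed change.

```python
import os

import torch.distributed as dist


def init_groups():
    # Rank/world size from the usual launcher environment variables (assumed here).
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Default (GPU) process group; note that no device_id is passed.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # Additional CPU process group over the same ranks, using the gloo backend,
    # e.g. for CPU-side coordination as in vLLM's parallel_state.py.
    cpu_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
    return cpu_group
```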
```diff
@@ -147,17 +147,14 @@ class TorchBackend(Backend):
 
     def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
-            kwargs = dict(
-                timeout=timeout,
-                init_method=init_method,
-                rank=rank,
-                world_size=world_size,
-            )
+            kwargs = dict(timeout=timeout, init_method=init_method, rank=rank, world_size=world_size)
 
             # 1. device_id arg was added in torch==2.3
             # 2. setting device_id leads to hanging in 2.6.0<torch<2.7.1 https://github.com/pytorch/pytorch/issues/153960
-            if 'device_id' in inspect.signature(torch.distributed.init_process_group).parameters and not (
-                    version.parse("2.6.0") < version.parse(torch.__version__) < version.parse("2.7.1")):
+            # 3. device_id works and is needed for `cuda`, other accelerators may have issues at the moment. Therefore only do it for the `cuda` accelerator.
+            if ('device_id' in inspect.signature(torch.distributed.init_process_group).parameters
+                    and not (version.parse("2.6.0") < version.parse(torch.__version__) < version.parse("2.7.1"))
+                    and get_accelerator().device_name() == 'cuda'):
                 local_rank = int(os.environ.get('LOCAL_RANK', 0))
                 kwargs.update(device_id=get_accelerator().device(local_rank))
             torch.distributed.init_process_group(backend, **kwargs)
```
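As a hedged usage sketch (not part of the diff): the guarded code above runs when DeepSpeed initializes its communication backend, and a typical entry point is `deepspeed.init_distributed`, which eventually reaches `TorchBackend.init_process_group`. On a CUDA machine with a torch version outside the affected range, `device_id` is added to the kwargs; on other accelerators, or on torch versions between 2.6.0 and 2.7.1, it is omitted. The backend choice below is only an example.

```python
import deepspeed

# Initialize DeepSpeed's distributed backend; this path ends up in
# TorchBackend.init_process_group shown in the patch above.
deepspeed.init_distributed(dist_backend="nccl")
```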