From 66c70312f2de26a41a05460804889d12008c6954 Mon Sep 17 00:00:00 2001
From: "Ma, Guokai"
Date: Mon, 29 Sep 2025 01:19:49 +0800
Subject: [PATCH] Change current_device() to current_device_name() (#7600)

This PR fixes a bug where, in some places, `get_accelerator().current_device()`
is used instead of `get_accelerator().current_device_name()`. This is mostly
fine on GPU devices, but it breaks on CPU:

`torch.empty(3, device=get_accelerator().current_device())` <-- works only on CUDA devices
`torch.empty(3, device=torch.device(get_accelerator().current_device()))` <-- works on GPU devices, but not on CPU
`torch.empty(3, device=torch.device(get_accelerator().current_device_name()))` <-- works on both GPU devices and CPU
`torch.empty(3, device=get_accelerator().current_device_name())` <-- also works, but is less formal than the previous form

This bug surfaced when I tried to run AutoTP training on a Xeon server for
debugging purposes.

---------

Signed-off-by: Guokai Ma
---
 deepspeed/runtime/engine.py                              | 8 ++++----
 deepspeed/runtime/utils.py                               | 8 ++++----
 deepspeed/runtime/zero/partitioned_param_coordinator.py  | 2 +-
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 363da7a76..3b05a9f11 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -479,7 +479,7 @@ class DeepSpeedEngine(Module):
                 dist.broadcast_object_list(object_list=_src_args,
                                            src=bcast_rank,
                                            group=bcast_group,
-                                           device=get_accelerator().current_device())
+                                           device=torch.device(get_accelerator().current_device_name()))
                 # Rank 0 does not need to compare with itself
                 is_equal = True
             else:
@@ -487,19 +487,19 @@ class DeepSpeedEngine(Module):
                 dist.broadcast_object_list(object_list=_src_args,
                                            src=bcast_rank,
                                            group=bcast_group,
-                                           device=get_accelerator().current_device())
+                                           device=torch.device(get_accelerator().current_device_name()))
                 is_equal = compare_tensors_in_structures(args, _src_args[0])

             equal_tensor = torch.tensor(is_equal,
                                         dtype=self.communication_data_type,
-                                        device=get_accelerator().current_device())
+                                        device=torch.device(get_accelerator().current_device_name()))
             dist.all_reduce(equal_tensor, group=bcast_group)
             assert torch.equal(
                 equal_tensor,
                 torch.tensor(groups.get_tensor_model_parallel_world_size(),
                              dtype=self.communication_data_type,
-                             device=get_accelerator().current_device())
+                             device=torch.device(get_accelerator().current_device_name()))
             ), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency."
        bcast_rank = self.mpu.get_tensor_model_parallel_src_rank()
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 2e1181f0b..d101ecece 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -1165,8 +1165,8 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
         return False
     for val1, val2 in zip(inputs1, inputs2):
         if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
-            val1 = val1.to(get_accelerator().current_device())
-            val2 = val2.to(get_accelerator().current_device())
+            val1 = val1.to(torch.device(get_accelerator().current_device_name()))
+            val2 = val2.to(torch.device(get_accelerator().current_device_name()))
             if not torch.equal(val1, val2):
                 return False
         elif val1 != val2:
@@ -1179,8 +1179,8 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
     for key in inputs1:
         val1, val2 = inputs1[key], inputs2[key]
         if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
-            val1 = val1.to(get_accelerator().current_device())
-            val2 = val2.to(get_accelerator().current_device())
+            val1 = val1.to(torch.device(get_accelerator().current_device_name()))
+            val2 = val2.to(torch.device(get_accelerator().current_device_name()))
             if not torch.equal(val1, val2):
                 return False
         elif val1 != val2:
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 8d82f5a1d..a754bee63 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -432,7 +432,7 @@ class PartitionedParameterCoordinator:
         free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
         if not free_data:
             # wait for the computation to finish and launch as early as possible.
-            empty_buffer = torch.empty(1, device=get_accelerator().current_device())
+            empty_buffer = torch.empty(1, device=torch.device(get_accelerator().current_device_name()))

        for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
            param.ds_active_sub_modules.discard(submodule.ds_id)
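
For reference, a minimal sketch (not part of the patch) of why the string form is the portable choice. It assumes a standard DeepSpeed install; the device that gets printed depends on whichever accelerator is detected at runtime:

```python
# Illustrative sketch, not part of the patch. Assumes `deepspeed` is installed.
import torch
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()

# current_device_name() returns a device string such as 'cuda:0', 'xpu:0', or
# 'cpu', which torch.device() accepts on every backend. current_device() may
# return a bare index, and torch.device(<int>) is interpreted as a CUDA device,
# which is why that form fails on CPU-only machines.
buf = torch.empty(3, device=torch.device(acc.current_device_name()))
print(buf.device)  # e.g. 'cuda:0' on a GPU box, 'cpu' on a Xeon server
```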