Change current_device() to current_device_name() (#7600)

This PR fixes a bug where `get_accelerator().current_device()` is used in some places instead of `get_accelerator().current_device_name()`. On GPU accelerators this is mostly fine, but on CPU it breaks:

- `torch.empty(3, device=get_accelerator().current_device())` <-- only works on CUDA devices
- `torch.empty(3, device=torch.device(get_accelerator().current_device()))` <-- works on GPU devices, but not on CPU
- `torch.empty(3, device=torch.device(get_accelerator().current_device_name()))` <-- works on both GPU and CPU
- `torch.empty(3, device=get_accelerator().current_device_name())` <-- also works, but is less explicit than the previous form
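For illustration, a minimal sketch of the recommended pattern (assuming a stock DeepSpeed install; on a CUDA accelerator `current_device()` returns an integer ordinal, while `current_device_name()` returns a device string that `torch.device()` accepts on every backend):

```python
import torch
from deepspeed.accelerator import get_accelerator

# current_device_name() returns a device string ("cuda:0", "xpu:0", "cpu", ...),
# so the same line works on both GPU and CPU accelerators.
dev = torch.device(get_accelerator().current_device_name())
buf = torch.empty(3, device=dev)
print(buf.device)
```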

This bug was exposed when I tried to run AutoTP training on a Xeon server for debugging purposes.

---------

Signed-off-by: Guokai Ma <guokai.ma@gmail.com>
Author: Ma, Guokai
Date: 2025-09-29 01:19:49 +08:00
Committed by: GitHub
Parent: 91d14527b6
Commit: 66c70312f2
3 changed files with 9 additions and 9 deletions


@@ -479,7 +479,7 @@ class DeepSpeedEngine(Module):
 dist.broadcast_object_list(object_list=_src_args,
 src=bcast_rank,
 group=bcast_group,
-device=get_accelerator().current_device())
+device=torch.device(get_accelerator().current_device_name()))
 # Rank 0 does not need to compare with itself
 is_equal = True
 else:
@@ -487,19 +487,19 @@ class DeepSpeedEngine(Module):
 dist.broadcast_object_list(object_list=_src_args,
 src=bcast_rank,
 group=bcast_group,
-device=get_accelerator().current_device())
+device=torch.device(get_accelerator().current_device_name()))
 is_equal = compare_tensors_in_structures(args, _src_args[0])
 equal_tensor = torch.tensor(is_equal,
 dtype=self.communication_data_type,
-device=get_accelerator().current_device())
+device=torch.device(get_accelerator().current_device_name()))
 dist.all_reduce(equal_tensor, group=bcast_group)
 assert torch.equal(
 equal_tensor,
 torch.tensor(groups.get_tensor_model_parallel_world_size(),
 dtype=self.communication_data_type,
-device=get_accelerator().current_device())
+device=torch.device(get_accelerator().current_device_name()))
 ), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency."
 bcast_rank = self.mpu.get_tensor_model_parallel_src_rank()


@@ -1165,8 +1165,8 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
 return False
 for val1, val2 in zip(inputs1, inputs2):
 if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
-val1 = val1.to(get_accelerator().current_device())
-val2 = val2.to(get_accelerator().current_device())
+val1 = val1.to(torch.device(get_accelerator().current_device_name()))
+val2 = val2.to(torch.device(get_accelerator().current_device_name()))
 if not torch.equal(val1, val2):
 return False
 elif val1 != val2:
@@ -1179,8 +1179,8 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
 for key in inputs1:
 val1, val2 = inputs1[key], inputs2[key]
 if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
-val1 = val1.to(get_accelerator().current_device())
-val2 = val2.to(get_accelerator().current_device())
+val1 = val1.to(torch.device(get_accelerator().current_device_name()))
+val2 = val2.to(torch.device(get_accelerator().current_device_name()))
 if not torch.equal(val1, val2):
 return False
 elif val1 != val2:


@@ -432,7 +432,7 @@ class PartitionedParameterCoordinator:
 free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
 if not free_data:
 # wait for the computation to finish and launch as early as possible.
-empty_buffer = torch.empty(1, device=get_accelerator().current_device())
+empty_buffer = torch.empty(1, device=torch.device(get_accelerator().current_device_name()))
 for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
 param.ds_active_sub_modules.discard(submodule.ds_id)