mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Change current_device() to current_device_name() (#7600)
This PR fixes a bug where, in some places, get_accelerator().current_device() is used instead of get_accelerator().current_device_name(). This is mostly fine, but on CPU it does not work: `torch.empty(3, device=get_accelerator().current_device())` <-- does not work for devices other than CUDA. `torch.empty(3, device=torch.device(get_accelerator().current_device()))` <-- works for GPU devices, but not for CPU. `torch.empty(3, device=torch.device(get_accelerator().current_device_name()))` <-- works for both GPU devices and CPU. `torch.empty(3, device=get_accelerator().current_device_name())` <-- this also works, but is not as formal as the previous one. This bug was exposed when I tried to run AutoTP training on a Xeon server for debugging purposes. --------- Signed-off-by: Guokai Ma <guokai.ma@gmail.com>
This commit is contained in:
@ -479,7 +479,7 @@ class DeepSpeedEngine(Module):
|
||||
dist.broadcast_object_list(object_list=_src_args,
|
||||
src=bcast_rank,
|
||||
group=bcast_group,
|
||||
device=get_accelerator().current_device())
|
||||
device=torch.device(get_accelerator().current_device_name()))
|
||||
# Rank 0 does not need to compare with itself
|
||||
is_equal = True
|
||||
else:
|
||||
@ -487,19 +487,19 @@ class DeepSpeedEngine(Module):
|
||||
dist.broadcast_object_list(object_list=_src_args,
|
||||
src=bcast_rank,
|
||||
group=bcast_group,
|
||||
device=get_accelerator().current_device())
|
||||
device=torch.device(get_accelerator().current_device_name()))
|
||||
|
||||
is_equal = compare_tensors_in_structures(args, _src_args[0])
|
||||
|
||||
equal_tensor = torch.tensor(is_equal,
|
||||
dtype=self.communication_data_type,
|
||||
device=get_accelerator().current_device())
|
||||
device=torch.device(get_accelerator().current_device_name()))
|
||||
dist.all_reduce(equal_tensor, group=bcast_group)
|
||||
assert torch.equal(
|
||||
equal_tensor,
|
||||
torch.tensor(groups.get_tensor_model_parallel_world_size(),
|
||||
dtype=self.communication_data_type,
|
||||
device=get_accelerator().current_device())
|
||||
device=torch.device(get_accelerator().current_device_name()))
|
||||
), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency."
|
||||
|
||||
bcast_rank = self.mpu.get_tensor_model_parallel_src_rank()
|
||||
|
@ -1165,8 +1165,8 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
|
||||
return False
|
||||
for val1, val2 in zip(inputs1, inputs2):
|
||||
if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
|
||||
val1 = val1.to(get_accelerator().current_device())
|
||||
val2 = val2.to(get_accelerator().current_device())
|
||||
val1 = val1.to(torch.device(get_accelerator().current_device_name()))
|
||||
val2 = val2.to(torch.device(get_accelerator().current_device_name()))
|
||||
if not torch.equal(val1, val2):
|
||||
return False
|
||||
elif val1 != val2:
|
||||
@ -1179,8 +1179,8 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
|
||||
for key in inputs1:
|
||||
val1, val2 = inputs1[key], inputs2[key]
|
||||
if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
|
||||
val1 = val1.to(get_accelerator().current_device())
|
||||
val2 = val2.to(get_accelerator().current_device())
|
||||
val1 = val1.to(torch.device(get_accelerator().current_device_name()))
|
||||
val2 = val2.to(torch.device(get_accelerator().current_device_name()))
|
||||
if not torch.equal(val1, val2):
|
||||
return False
|
||||
elif val1 != val2:
|
||||
|
@ -432,7 +432,7 @@ class PartitionedParameterCoordinator:
|
||||
free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
|
||||
if not free_data:
|
||||
# wait for the computation to finish and launch as early as possible.
|
||||
empty_buffer = torch.empty(1, device=get_accelerator().current_device())
|
||||
empty_buffer = torch.empty(1, device=torch.device(get_accelerator().current_device_name()))
|
||||
|
||||
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
|
||||
param.ds_active_sub_modules.discard(submodule.ds_id)
|
||||
|
Reference in New Issue
Block a user