mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 23:53:48 +08:00
## Environment
```
torch 2.7.1
torch_npu 2.7.1rc1
deepspeed 0.17.3
```
## Issue
Since DeepSpeed v0.17.3, an `AttributeError` is raised when
`init_process_group` is called on an NPU device.
The issue is similar to
https://github.com/deepspeedai/DeepSpeed/pull/7488.
Trace:
```
Traceback (most recent call last):
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/cli/sft.py", line 10, in <module>
    sft_main()
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/train/sft.py", line 331, in sft_main
    return SwiftSft(args).main()
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/train/sft.py", line 27, in __init__
    super().__init__(args)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/base.py", line 19, in __init__
    self.args = self._parse_args(args)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/base.py", line 31, in _parse_args
    args, remaining_argv = parse_args(self.args_class, args)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/utils/utils.py", line 152, in parse_args
    args, remaining_args = parser.parse_args_into_dataclasses(argv, return_remaining_strings=True)
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/hf_argparser.py", line 358, in parse_args_into_dataclasses
    obj = dtype(**inputs)
  File "<string>", line 325, in __init__
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/argument/train_args.py", line 175, in __post_init__
    self.training_args = TrainerFactory.get_training_args(self)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/trainer_factory.py", line 70, in get_training_args
    return training_args_cls(**args_dict)
  File "<string>", line 167, in __init__
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/arguments.py", line 152, in __post_init__
    super().__post_init__()
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/arguments.py", line 133, in __post_init__
    super().__post_init__()
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 1803, in __post_init__
    self.device
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 2332, in device
    return self._setup_devices
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 74, in __get__
    cached = self.fget(obj)
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 2259, in _setup_devices
    self.distributed_state = PartialState(**accelerator_state_kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/accelerate/state.py", line 216, in __init__
    dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/comm.py", line 854, in init_distributed
    cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
  File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/torch.py", line 120, in __init__
    self.init_process_group(backend, timeout, init_method, rank, world_size)
  File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/torch.py", line 163, in init_process_group
    torch.distributed.init_process_group(backend, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 95, in wrapper
    func_return = func(*args, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1717, in init_process_group
    default_pg, _ = _new_process_group_helper(
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1831, in _new_process_group_helper
    if device_id is not None and (device_id.index is None or device_id.type == "cpu"):
AttributeError: 'device' object has no attribute 'index'
```
## Fix
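Switch `torch.npu.device(device_index)` to `torch.device('npu', device_index)`.
The former appears to return a device context manager rather than a
`torch.device`, so the object passed along as `device_id` has no `index`
attribute, which matches the `AttributeError` in the trace.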
Now:
d40a0f5de8/accelerator/npu_accelerator.py (L47-L48)
After fix:
```python
def device(self, device_index=None):
    return torch.device('npu', device_index)
```
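The failure mode can be reproduced without NPU hardware using a small, torch-free sketch. Both classes below are hypothetical stand-ins: `ContextManagerDevice` imitates an object that exposes only an `idx` field (an assumption about what `torch.npu.device(...)` returns), and `PlainDevice` imitates a `torch.device` with `type` and `index`. `validate_device_id` copies only the condition shown in the last traceback frame; the error message is illustrative, not PyTorch's.

```python
from dataclasses import dataclass
from typing import Optional


class ContextManagerDevice:
    """Hypothetical stand-in for a device context manager (no `index`/`type`)."""

    def __init__(self, idx):
        self.idx = idx  # assumed field name; the point is that `index` is absent

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


@dataclass(frozen=True)
class PlainDevice:
    """Hypothetical stand-in for `torch.device(type, index)`."""
    type: str
    index: Optional[int] = None


def validate_device_id(device_id):
    """Apply the same condition as the last frame of the traceback."""
    if device_id is not None and (device_id.index is None or device_id.type == "cpu"):
        # Illustrative message only.
        raise ValueError("device_id must be a non-CPU device with an index")
    return device_id


if __name__ == "__main__":
    try:
        validate_device_id(ContextManagerDevice(0))
    except AttributeError as err:
        print("before fix:", err)

    print("after fix:", validate_device_id(PlainDevice("npu", 0)))
```

The stand-in for the old return value fails on attribute access before the check can even evaluate, while the `torch.device`-like object passes through unchanged.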
Signed-off-by: welsper <welsper@qq.com>
Co-authored-by: welsper <xinyuyang@cmbchina.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>