Files
DeepSpeed/accelerator/npu_accelerator.py
Welsper aa539c6dd5 fix npu device_id AttributeError issue (#7560)
## Environment
```
torch        2.7.1
torch_npu    2.7.1rc1
deepspeed    0.17.3
```
## Issue
Since deepspeed v0.17.3, an `AttributeError` is raised when
`init_process_group` is called on an NPU device.
The issue is similar to
https://github.com/deepspeedai/DeepSpeed/pull/7488.

Trace:
```
Traceback (most recent call last):
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/cli/sft.py", line 10, in <module>
    sft_main()
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/train/sft.py", line 331, in sft_main
    return SwiftSft(args).main()
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/train/sft.py", line 27, in __init__
    super().__init__(args)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/base.py", line 19, in __init__
    self.args = self._parse_args(args)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/base.py", line 31, in _parse_args
    args, remaining_argv = parse_args(self.args_class, args)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/utils/utils.py", line 152, in parse_args
    args, remaining_args = parser.parse_args_into_dataclasses(argv, return_remaining_strings=True)
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/hf_argparser.py", line 358, in parse_args_into_dataclasses
    obj = dtype(**inputs)
  File "<string>", line 325, in __init__
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/argument/train_args.py", line 175, in __post_init__
    self.training_args = TrainerFactory.get_training_args(self)
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/trainer_factory.py", line 70, in get_training_args
    return training_args_cls(**args_dict)
  File "<string>", line 167, in __init__
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/arguments.py", line 152, in __post_init__
    super().__post_init__()
  File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/arguments.py", line 133, in __post_init__
    super().__post_init__()
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 1803, in __post_init__
    self.device
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 2332, in device
    return self._setup_devices
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 74, in __get__
    cached = self.fget(obj)
  File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 2259, in _setup_devices
    self.distributed_state = PartialState(**accelerator_state_kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/accelerate/state.py", line 216, in __init__
    dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/comm.py", line 854, in init_distributed
    cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
  File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/torch.py", line 120, in __init__
    self.init_process_group(backend, timeout, init_method, rank, world_size)
  File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/torch.py", line 163, in init_process_group
    torch.distributed.init_process_group(backend, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 95, in wrapper
    func_return = func(*args, **kwargs)
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1717, in init_process_group
    default_pg, _ = _new_process_group_helper(
  File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1831, in _new_process_group_helper
    if device_id is not None and (device_id.index is None or device_id.type == "cpu"):
AttributeError: 'device' object has no attribute 'index'
```

## Fix
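Switch `torch.npu.device(device_index)` to `torch.device('npu',
device_index)`. As the trace shows, `init_process_group` reads
`device_id.index`, and the object returned by `torch.npu.device(...)` has no
`index` attribute, whereas a `torch.device` does.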

Current code (before the fix):

d40a0f5de8/accelerator/npu_accelerator.py (L47-L48)

After fix:
```python
 def device(self, device_index=None): 
     return torch.device('npu', device_index) 
```

Signed-off-by: welsper <welsper@qq.com>
Co-authored-by: welsper <xinyuyang@cmbchina.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
2025-09-17 15:46:33 +08:00

300 lines
8.8 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import importlib
import inspect
from .abstract_accelerator import DeepSpeedAccelerator
# During the setup stage torch may not be installed. Ignoring the ImportError
# still allows the op-builder-related APIs to be executed without torch.
try:
import torch.npu
except ImportError:
pass
class NPU_Accelerator(DeepSpeedAccelerator):
    """DeepSpeed accelerator implementation backed by ``torch.npu``.

    Most methods forward directly to the ``torch.npu`` attribute of the same
    name. The communication backend name is ``hccl`` and the default compile
    backend is ``inductor``.
    """

    def __init__(self):
        """Set the accelerator name, communication backend and compile backend."""
        super().__init__()
        self._name = 'npu'
        self._communication_backend_name = 'hccl'
        self._compile_backend = "inductor"
        # Maps an op builder class name to its class object, e.g.
        # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>.
        # Starts as None; _lazy_init_class_dict() fills it on first use.
        self.class_dict = None

    def is_synchronized_device(self):
        """Return False: this accelerator does not report itself as synchronized."""
        return False

    def use_host_timers(self):
        """Return the same answer as ``is_synchronized_device()``."""
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        """Return the same answer as ``is_synchronized_device()``."""
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        """Return the same answer as ``is_synchronized_device()``."""
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        """Return ``'npu'``, or ``'npu:<device_index>'`` when an index is given."""
        return 'npu' if device_index is None else 'npu:{}'.format(device_index)

    def device(self, device_index=None):
        """Return ``torch.device('npu', device_index)``.

        A real ``torch.device`` is returned because
        ``torch.distributed.init_process_group`` reads ``device_id.index``.
        """
        return torch.device('npu', device_index)

    def set_device(self, device_index):
        """Forward to ``torch.npu.set_device(device_index)``; returns nothing."""
        torch.npu.set_device(device_index)

    def current_device(self):
        """Return ``torch.npu.current_device()``."""
        return torch.npu.current_device()

    def current_device_name(self):
        """Return ``'npu:<n>'`` where ``n`` is ``torch.npu.current_device()``."""
        active_index = torch.npu.current_device()
        return 'npu:{}'.format(active_index)

    def device_count(self):
        """Return ``torch.npu.device_count()``."""
        return torch.npu.device_count()

    def synchronize(self, device_index=None):
        """Return the result of ``torch.npu.synchronize(device_index)``."""
        return torch.npu.synchronize(device_index)

    # RNG APIs
    def random(self):
        """Return the ``torch.random`` module."""
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        """Forward to ``torch.npu.set_rng_state``.

        ``device_index`` is only passed through when it is not None.
        """
        if device_index is not None:
            return torch.npu.set_rng_state(new_state, device_index)
        return torch.npu.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        """Forward to ``torch.npu.get_rng_state``.

        ``device_index`` is only passed through when it is not None.
        """
        if device_index is not None:
            return torch.npu.get_rng_state(device_index)
        return torch.npu.get_rng_state()

    def manual_seed(self, seed):
        """Return the result of ``torch.npu.manual_seed(seed)``."""
        return torch.npu.manual_seed(seed)

    def manual_seed_all(self, seed):
        """Return the result of ``torch.npu.manual_seed_all(seed)``."""
        return torch.npu.manual_seed_all(seed)

    def initial_seed(self):
        """Return ``torch.npu.initial_seed()``."""
        return torch.npu.initial_seed()

    def default_generator(self, device_index):
        """Return ``torch.npu.default_generators[device_index]``."""
        generators = torch.npu.default_generators
        return generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        """The ``torch.npu.Stream`` class."""
        return torch.npu.Stream

    def stream(self, stream):
        """Return the result of ``torch.npu.stream(stream)``."""
        return torch.npu.stream(stream)

    def current_stream(self, device_index=None):
        """Return ``torch.npu.current_stream(device_index)``."""
        return torch.npu.current_stream(device_index)

    def default_stream(self, device_index=None):
        """Return ``torch.npu.default_stream(device_index)``."""
        return torch.npu.default_stream(device_index)

    @property
    def Event(self):
        """The ``torch.npu.Event`` class."""
        return torch.npu.Event

    # Memory management
    def empty_cache(self):
        """Return the result of ``torch.npu.empty_cache()``."""
        return torch.npu.empty_cache()

    def memory_allocated(self, device_index=None):
        """Return ``torch.npu.memory_allocated(device_index)``."""
        return torch.npu.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        """Return ``torch.npu.max_memory_allocated(device_index)``."""
        return torch.npu.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        """Return the result of ``torch.npu.reset_max_memory_allocated(device_index)``."""
        return torch.npu.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        """Return ``torch.npu.memory_cached(device_index)``."""
        return torch.npu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        """Return ``torch.npu.max_memory_cached(device_index)``."""
        return torch.npu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        """Return the result of ``torch.npu.reset_max_memory_cached(device_index)``."""
        return torch.npu.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        """Return ``torch.npu.memory_stats(device_index)``, or None if ``torch.npu`` lacks it."""
        if not hasattr(torch.npu, 'memory_stats'):
            return None
        return torch.npu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        """Call ``torch.npu.reset_peak_memory_stats(device_index)``; None if ``torch.npu`` lacks it."""
        if not hasattr(torch.npu, 'reset_peak_memory_stats'):
            return None
        return torch.npu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        """Return ``torch.npu.memory_reserved(device_index)``, or None if ``torch.npu`` lacks it."""
        if not hasattr(torch.npu, 'memory_reserved'):
            return None
        return torch.npu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        """Return ``torch.npu.max_memory_reserved(device_index)``, or None if ``torch.npu`` lacks it."""
        if not hasattr(torch.npu, 'max_memory_reserved'):
            return None
        return torch.npu.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        """Return ``total_memory`` from ``torch.npu.get_device_properties(device_index)``."""
        properties = torch.npu.get_device_properties(device_index)
        return properties.total_memory

    def available_memory(self, device_index=None):
        """Return ``total_memory(device_index)`` minus ``memory_allocated(device_index)``."""
        total = self.total_memory(device_index)
        return total - self.memory_allocated(device_index)

    # Data types
    def is_bf16_supported(self):
        """Return ``torch.npu.is_bf16_supported()``."""
        return torch.npu.is_bf16_supported()

    def is_fp16_supported(self):
        """Return True unconditionally."""
        return True

    def supported_dtypes(self):
        """Return the list ``[torch.float, torch.half, torch.bfloat16]``."""
        return [torch.float, torch.half, torch.bfloat16]

    # Misc
    def amp(self):
        """Return ``torch.npu.amp`` when that attribute exists, otherwise None."""
        return getattr(torch.npu, 'amp', None)

    def is_available(self):
        """Return ``torch.npu.is_available()``."""
        return torch.npu.is_available()

    def range_push(self, msg):
        """Do nothing; ``msg`` is ignored and None is returned."""
        return

    def range_pop(self):
        """Do nothing and return None."""
        return

    def lazy_call(self, callback):
        """Return the result of ``torch.npu._lazy_call(callback)``."""
        return torch.npu._lazy_call(callback)

    def communication_backend_name(self):
        """Return the communication backend name set in ``__init__`` (``'hccl'``)."""
        return self._communication_backend_name

    def is_triton_supported(self):
        """Return False unconditionally."""
        return False

    # Graph operations
    def create_graph(self):
        """Return None; no graph object is created here."""
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        """Return ``noop_context()`` from ``deepspeed.runtime.utils``.

        ``graph``, ``pool`` and ``stream`` are accepted but not used.
        """
        # Imported here rather than at module level, as in the original code.
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        """Do nothing; ``graph`` is ignored and None is returned."""
        return

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        """The ``torch.npu.BFloat16Tensor`` attribute."""
        return torch.npu.BFloat16Tensor

    @property
    def ByteTensor(self):
        """The ``torch.npu.ByteTensor`` attribute."""
        return torch.npu.ByteTensor

    @property
    def DoubleTensor(self):
        """The ``torch.npu.DoubleTensor`` attribute."""
        return torch.npu.DoubleTensor

    @property
    def FloatTensor(self):
        """The ``torch.npu.FloatTensor`` attribute."""
        return torch.npu.FloatTensor

    @property
    def HalfTensor(self):
        """The ``torch.npu.HalfTensor`` attribute."""
        return torch.npu.HalfTensor

    @property
    def IntTensor(self):
        """The ``torch.npu.IntTensor`` attribute."""
        return torch.npu.IntTensor

    @property
    def LongTensor(self):
        """The ``torch.npu.LongTensor`` attribute."""
        return torch.npu.LongTensor

    def pin_memory(self, tensor, align_bytes=1):
        """Return ``tensor.pin_memory()``; ``align_bytes`` is accepted but not used."""
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        """Return ``tensor.is_pinned()``."""
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        """Return True when ``str(tensor.device)`` starts with ``'npu:'``, else False."""
        return str(tensor.device).startswith('npu:')

    def op_builder_dir(self):
        """Return the dotted module path of the NPU op builder package.

        Returns ``"op_builder.npu"`` when ``op_builder.__deepspeed__`` can be
        imported, and ``"deepspeed.ops.op_builder.npu"`` otherwise.
        """
        try:
            # Is op_builder from deepspeed or a 3p version? This should only succeed if it's deepspeed.
            # If successful this also means we're doing a local install and not the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
        except ImportError:
            return "deepspeed.ops.op_builder.npu"
        return "op_builder.npu"

    def _lazy_init_class_dict(self):
        """Populate ``self.class_dict`` from the op builder module on first use.

        A falsy ``class_dict`` (None or empty) triggers the import; a populated
        one is left untouched.
        """
        if not self.class_dict:
            builder_module = importlib.import_module(self.op_builder_dir())
            # Collect every class exposed by op_builder/npu/__init__.py.
            self.class_dict = {
                builder_name: builder_cls
                for builder_name, builder_cls in inspect.getmembers(builder_module, inspect.isclass)
            }

    def create_op_builder(self, class_name):
        """Return a new instance of the op builder named ``class_name``.

        Returns None when ``get_op_builder`` finds no class to instantiate.
        """
        builder_class = self.get_op_builder(class_name)
        if builder_class is None:
            return None
        return builder_class()

    def get_op_builder(self, class_name):
        """Return the op builder class registered under ``class_name``.

        Falls back to the ``'NotImplementedBuilder'`` entry when the name is
        unknown, and to None when that entry is missing as well.
        """
        self._lazy_init_class_dict()
        registry = self.class_dict
        if class_name in registry:
            return registry[class_name]
        return registry.get('NotImplementedBuilder')

    def build_extension(self):
        """Return the ``BuildExtension`` class from ``torch.utils.cpp_extension``."""
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        """Return the list ``['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']``."""
        # NOTE(review): presumably name patterns of environment variables to export; confirm with the caller.
        return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']

    def visible_devices_envs(self):
        """Return the list ``['ASCEND_RT_VISIBLE_DEVICES']``."""
        return ['ASCEND_RT_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        """Write the comma-joined accelerator ids into ``current_env``.

        Every name returned by ``visible_devices_envs()`` is set, in place, to
        the ids converted with ``str`` and joined by ``","``.
        """
        for env_name in self.visible_devices_envs():
            current_env[env_name] = ",".join(str(accelerator_id) for accelerator_id in local_accelerator_ids)

    def get_compile_backend(self):
        """Return the currently selected compile backend name."""
        return self._compile_backend

    def set_compile_backend(self, backend):
        """Select ``backend`` as the compile backend.

        Raises:
            ValueError: if ``backend`` is not among
                ``torch._dynamo.list_backends(exclude_tags=())``.
        """
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend not in supported_backends:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }")
        self._compile_backend = backend