mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
## Environment
```
torch 2.7.1
torch_npu 2.7.1rc1
deepspeed 0.17.3
```
## Issue
An `AttributeError` is raised when `init_process_group` is called on an NPU
device since DeepSpeed v0.17.3.
The issue is similar to
https://github.com/deepspeedai/DeepSpeed/pull/7488.
Trace:
```
Traceback (most recent call last):
File "/home/welsper/.local/lib/python3.10/site-packages/swift/cli/sft.py", line 10, in <module>
sft_main()
File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/train/sft.py", line 331, in sft_main
return SwiftSft(args).main()
File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/train/sft.py", line 27, in __init__
super().__init__(args)
File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/base.py", line 19, in __init__
self.args = self._parse_args(args)
File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/base.py", line 31, in _parse_args
args, remaining_argv = parse_args(self.args_class, args)
File "/home/welsper/.local/lib/python3.10/site-packages/swift/utils/utils.py", line 152, in parse_args
args, remaining_args = parser.parse_args_into_dataclasses(argv, return_remaining_strings=True)
File "/home/welsper/.local/lib/python3.10/site-packages/transformers/hf_argparser.py", line 358, in parse_args_into_dataclasses
obj = dtype(**inputs)
File "<string>", line 325, in __init__
File "/home/welsper/.local/lib/python3.10/site-packages/swift/llm/argument/train_args.py", line 175, in __post_init__
self.training_args = TrainerFactory.get_training_args(self)
File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/trainer_factory.py", line 70, in get_training_args
return training_args_cls(**args_dict)
File "<string>", line 167, in __init__
File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/arguments.py", line 152, in __post_init__
super().__post_init__()
File "/home/welsper/.local/lib/python3.10/site-packages/swift/trainers/arguments.py", line 133, in __post_init__
super().__post_init__()
File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 1803, in __post_init__
self.device
File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 2332, in device
return self._setup_devices
File "/home/welsper/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 74, in __get__
cached = self.fget(obj)
File "/home/welsper/.local/lib/python3.10/site-packages/transformers/training_args.py", line 2259, in _setup_devices
self.distributed_state = PartialState(**accelerator_state_kwargs)
File "/home/welsper/.local/lib/python3.10/site-packages/accelerate/state.py", line 216, in __init__
dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/comm.py", line 854, in init_distributed
cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/torch.py", line 120, in __init__
self.init_process_group(backend, timeout, init_method, rank, world_size)
File "/home/welsper/.local/lib/python3.10/site-packages/deepspeed/comm/torch.py", line 163, in init_process_group
torch.distributed.init_process_group(backend, **kwargs)
File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
return func(*args, **kwargs)
File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 95, in wrapper
func_return = func(*args, **kwargs)
File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1717, in init_process_group
default_pg, _ = _new_process_group_helper(
File "/home/welsper/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1831, in _new_process_group_helper
if device_id is not None and (device_id.index is None or device_id.type == "cpu"):
AttributeError: 'device' object has no attribute 'index'
```
## Fix
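Switch `torch.npu.device(device_index)` to `torch.device('npu',
device_index)`, so that `device()` returns an object that has the `index`
attribute which `_new_process_group_helper` reads from `device_id`.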
Now:
d40a0f5de8/accelerator/npu_accelerator.py (L47-L48)
After fix:
```python
def device(self, device_index=None):
return torch.device('npu', device_index)
```
Signed-off-by: welsper <welsper@qq.com>
Co-authored-by: welsper <xinyuyang@cmbchina.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
300 lines
8.8 KiB
Python
300 lines
8.8 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
import importlib
|
|
import inspect
|
|
|
|
from .abstract_accelerator import DeepSpeedAccelerator
|
|
# During setup stage torch may not be installed, pass on no torch will
|
|
# allow op builder related API to be executed.
|
|
try:
|
|
import torch.npu
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
class NPU_Accelerator(DeepSpeedAccelerator):
    """DeepSpeed accelerator implementation backed by ``torch.npu``.

    Most methods are thin delegations to the matching ``torch.npu`` API.
    Features that have no implementation here (graphs, range markers,
    triton) are exposed as no-ops or report that they are unsupported.
    """

    def __init__(self):
        super().__init__()
        self._name = 'npu'
        self._communication_backend_name = 'hccl'
        self._compile_backend = "inductor"
        # dict that holds class name <--> class type mapping i.e.
        # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
        # It stays None until _lazy_init_class_dict() fills it on first use.
        self.class_dict = None

    def is_synchronized_device(self):
        """Return False: this accelerator is not treated as synchronized."""
        return False

    def use_host_timers(self):
        """Return the same answer as ``is_synchronized_device()``."""
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        """Return the same answer as ``is_synchronized_device()``."""
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        """Return the same answer as ``is_synchronized_device()``."""
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        """Return ``'npu'``, or ``'npu:<device_index>'`` when an index is given."""
        if device_index is None:
            return 'npu'
        return 'npu:{}'.format(device_index)

    def device(self, device_index=None):
        """Return a ``torch.device`` of type ``'npu'`` for ``device_index``."""
        # A real torch.device is returned on purpose: callers such as
        # torch.distributed's process-group setup read `.index` and `.type`
        # from the object they are handed as `device_id`.
        return torch.device('npu', device_index)

    def set_device(self, device_index):
        """Make ``device_index`` the current NPU device."""
        torch.npu.set_device(device_index)

    def current_device(self):
        """Return the value of ``torch.npu.current_device()``."""
        return torch.npu.current_device()

    def current_device_name(self):
        """Return ``'npu:<current device>'``."""
        return 'npu:{}'.format(torch.npu.current_device())

    def device_count(self):
        """Return the value of ``torch.npu.device_count()``."""
        return torch.npu.device_count()

    def synchronize(self, device_index=None):
        """Delegate to ``torch.npu.synchronize(device_index)``."""
        return torch.npu.synchronize(device_index)

    # RNG APIs
    def random(self):
        """Return the ``torch.random`` module."""
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        """Set the NPU RNG state, for ``device_index`` when one is given."""
        # With no index, let torch.npu apply its own default device argument.
        if device_index is None:
            return torch.npu.set_rng_state(new_state)

        return torch.npu.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        """Return the NPU RNG state, for ``device_index`` when one is given."""
        # With no index, let torch.npu apply its own default device argument.
        if device_index is None:
            return torch.npu.get_rng_state()

        return torch.npu.get_rng_state(device_index)

    def manual_seed(self, seed):
        """Delegate to ``torch.npu.manual_seed(seed)``."""
        return torch.npu.manual_seed(seed)

    def manual_seed_all(self, seed):
        """Delegate to ``torch.npu.manual_seed_all(seed)``."""
        return torch.npu.manual_seed_all(seed)

    def initial_seed(self):
        """Delegate to ``torch.npu.initial_seed()``."""
        return torch.npu.initial_seed()

    def default_generator(self, device_index):
        """Return ``torch.npu.default_generators[device_index]``."""
        return torch.npu.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        """The ``torch.npu.Stream`` class."""
        return torch.npu.Stream

    def stream(self, stream):
        """Delegate to ``torch.npu.stream(stream)``."""
        return torch.npu.stream(stream)

    def current_stream(self, device_index=None):
        """Delegate to ``torch.npu.current_stream(device_index)``."""
        return torch.npu.current_stream(device_index)

    def default_stream(self, device_index=None):
        """Delegate to ``torch.npu.default_stream(device_index)``."""
        return torch.npu.default_stream(device_index)

    @property
    def Event(self):
        """The ``torch.npu.Event`` class."""
        return torch.npu.Event

    # Memory management
    def empty_cache(self):
        """Delegate to ``torch.npu.empty_cache()``."""
        return torch.npu.empty_cache()

    def memory_allocated(self, device_index=None):
        """Delegate to ``torch.npu.memory_allocated(device_index)``."""
        return torch.npu.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        """Delegate to ``torch.npu.max_memory_allocated(device_index)``."""
        return torch.npu.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        """Delegate to ``torch.npu.reset_max_memory_allocated(device_index)``."""
        return torch.npu.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        """Delegate to ``torch.npu.memory_cached(device_index)``."""
        return torch.npu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        """Delegate to ``torch.npu.max_memory_cached(device_index)``."""
        return torch.npu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        """Delegate to ``torch.npu.reset_max_memory_cached(device_index)``."""
        return torch.npu.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        """Return ``torch.npu.memory_stats(device_index)``, or None if absent."""
        if hasattr(torch.npu, 'memory_stats'):
            return torch.npu.memory_stats(device_index)
        return None

    def reset_peak_memory_stats(self, device_index=None):
        """Call ``torch.npu.reset_peak_memory_stats`` when it exists, else return None."""
        if hasattr(torch.npu, 'reset_peak_memory_stats'):
            return torch.npu.reset_peak_memory_stats(device_index)
        return None

    def memory_reserved(self, device_index=None):
        """Return ``torch.npu.memory_reserved(device_index)``, or None if absent."""
        if hasattr(torch.npu, 'memory_reserved'):
            return torch.npu.memory_reserved(device_index)
        return None

    def max_memory_reserved(self, device_index=None):
        """Return ``torch.npu.max_memory_reserved(device_index)``, or None if absent."""
        if hasattr(torch.npu, 'max_memory_reserved'):
            return torch.npu.max_memory_reserved(device_index)
        return None

    def total_memory(self, device_index=None):
        """Return ``total_memory`` from the device properties of ``device_index``."""
        return torch.npu.get_device_properties(device_index).total_memory

    def available_memory(self, device_index=None):
        """Return total memory minus currently allocated memory for the device."""
        return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Data types
    def is_bf16_supported(self):
        """Delegate to ``torch.npu.is_bf16_supported()``."""
        return torch.npu.is_bf16_supported()

    def is_fp16_supported(self):
        """Return True unconditionally."""
        return True

    def supported_dtypes(self):
        """Return the list ``[torch.float, torch.half, torch.bfloat16]``."""
        return [torch.float, torch.half, torch.bfloat16]

    # Misc
    def amp(self):
        """Return ``torch.npu.amp`` when it exists, otherwise None."""
        if hasattr(torch.npu, 'amp'):
            return torch.npu.amp
        return None

    def is_available(self):
        """Delegate to ``torch.npu.is_available()``."""
        return torch.npu.is_available()

    def range_push(self, msg):
        """No-op; ``msg`` is ignored and None is returned."""
        return

    def range_pop(self):
        """No-op; returns None."""
        return

    def lazy_call(self, callback):
        """Delegate to ``torch.npu._lazy_call(callback)``."""
        return torch.npu._lazy_call(callback)

    def communication_backend_name(self):
        """Return the communication backend name set in ``__init__`` (``'hccl'``)."""
        return self._communication_backend_name

    def is_triton_supported(self):
        """Return False unconditionally."""
        return False

    # Graph operations
    def create_graph(self):
        """Return None; no graph object is created by this accelerator."""
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        """Return a no-op context manager; all arguments are ignored."""
        # Imported here rather than at module level, as in the original code.
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        """No-op; ``graph`` is ignored and None is returned."""
        return

    # Tensor operations

    @property
    def BFloat16Tensor(self):
        """The ``torch.npu.BFloat16Tensor`` type."""
        return torch.npu.BFloat16Tensor

    @property
    def ByteTensor(self):
        """The ``torch.npu.ByteTensor`` type."""
        return torch.npu.ByteTensor

    @property
    def DoubleTensor(self):
        """The ``torch.npu.DoubleTensor`` type."""
        return torch.npu.DoubleTensor

    @property
    def FloatTensor(self):
        """The ``torch.npu.FloatTensor`` type."""
        return torch.npu.FloatTensor

    @property
    def HalfTensor(self):
        """The ``torch.npu.HalfTensor`` type."""
        return torch.npu.HalfTensor

    @property
    def IntTensor(self):
        """The ``torch.npu.IntTensor`` type."""
        return torch.npu.IntTensor

    @property
    def LongTensor(self):
        """The ``torch.npu.LongTensor`` type."""
        return torch.npu.LongTensor

    def pin_memory(self, tensor, align_bytes=1):
        """Return ``tensor.pin_memory()``; ``align_bytes`` is not used here."""
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        """Return ``tensor.is_pinned()``."""
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        """Return True when ``str(tensor.device)`` starts with ``'npu:'``."""
        return str(tensor.device).startswith('npu:')

    def op_builder_dir(self):
        """Return the dotted module path that holds the NPU op builders.

        ``"op_builder.npu"`` is returned when ``op_builder.__deepspeed__`` can
        be imported, otherwise ``"deepspeed.ops.op_builder.npu"``.
        """
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.npu"
        except ImportError:
            return "deepspeed.ops.op_builder.npu"

    def _lazy_init_class_dict(self):
        """Populate ``self.class_dict`` on first use; later calls do nothing."""
        # Compare against None rather than relying on truthiness: an already
        # initialised but empty mapping must not trigger another
        # import-and-inspect pass on every lookup.
        if self.class_dict is not None:
            return

        op_builder_module = importlib.import_module(self.op_builder_dir())

        # get op builder class from op_builder/npu/__init__.py
        self.class_dict = dict(inspect.getmembers(op_builder_module, inspect.isclass))

    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, class_name):
        """Instantiate the builder class for ``class_name``; None if there is none."""
        builder_class = self.get_op_builder(class_name)
        return None if builder_class is None else builder_class()

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        """Return the builder class registered under ``class_name``.

        Falls back to the ``'NotImplementedBuilder'`` entry when the name is
        unknown, and to None when that entry is missing as well.
        """
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        return self.class_dict.get('NotImplementedBuilder')

    def build_extension(self):
        """Return ``torch.utils.cpp_extension.BuildExtension``."""
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        """Return the list ``['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']``."""
        return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']

    def visible_devices_envs(self):
        """Return the list ``['ASCEND_RT_VISIBLE_DEVICES']``."""
        return ['ASCEND_RT_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        """Write the comma-joined ids into ``current_env`` for each visible-devices variable."""
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        """Return the currently selected compile backend name."""
        return self._compile_backend

    def set_compile_backend(self, backend):
        """Select ``backend`` as the compile backend.

        Raises:
            ValueError: if ``backend`` is not among the names returned by
                ``torch._dynamo.list_backends(exclude_tags=())``.
        """
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
|