Files
DeepSpeed/accelerator/real_accelerator.py
RyanInnerpeace 4b7cae7bea [NPU] Add NPU support for unit test (#4569)
Unit tests fail or are skipped when device=npu, and we definitely want to
exercise all these wonderful features with the official unit tests.
This commit adds NPU support for the unit tests. P.S. see what we
have already done in #4567.


**What I do in this commit**
1. Just add npu logic branch 
feat: Add npu support for skip_on_arch in tests/unit/util.py
feat: Add npu support for skip_on_cuda in tests/unit/util.py
feat: Add npu support for tests/unit/common.py

2. Set_device of accelerator before deepspeed.init_distributed in
tests/unit/common.py
It is friendlier and easier for other devices such as NPU if we call the
accelerator's set_device before init_distributed. Setting the device
before initialization is also the more reasonable order.

3. Solve the problem of calling get_accelerator().random().fork_rng with
non-cuda device
Function `train_cifar()` in `tests/unit/alexnet_model.py` calls
`get_accelerator().random().fork_rng` without passing `device_type`
explicitly. Unfortunately, `torch.random.fork_rng()` defaults to
`device_type="cuda"`, so non-cuda devices fail to run. The
solution is to pass
`device_type=get_accelerator().device_name()` explicitly, so that both
cuda and non-cuda devices work correctly.

---------

Co-authored-by: ryan <ruanzhixiang1@huawei.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2023-11-13 20:36:12 +00:00

216 lines
8.4 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
try:
    # The logger is optional here: importing it currently pulls in torch, which may be absent.
    # TODO: Remove logger dependency on torch.
    from deepspeed.utils import logger as accel_logger
except ImportError:
    accel_logger = None

# The abstract base class is reachable under two module paths (build time vs. run time);
# keep whichever ones can be imported and mark the others as None.
try:
    from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
except ImportError:
    dsa1 = None

try:
    from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa2
except ImportError:
    dsa2 = None

# Accelerator names accepted by this module.
SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']

# Process-wide accelerator instance; populated by get_accelerator() / set_accelerator().
ds_accelerator = None
def _validate_accelerator(accel_obj):
    """Raise ``AssertionError`` unless ``accel_obj`` is a ``DeepSpeedAccelerator`` instance.

    The abstract base class lives at ``accelerator.abstract_accelerator`` during build time
    and at ``deepspeed.accelerator.abstract_accelerator`` during run time, and extensions
    derive from the run-time class. An object is therefore accepted when it is an instance
    of either base class that could be imported (``dsa1`` / ``dsa2``).
    """
    # Only base classes whose import succeeded take part in the check; if neither
    # was importable the tuple is empty and every object is rejected.
    importable_bases = tuple(base for base in (dsa1, dsa2) if base is not None)
    if isinstance(accel_obj, importable_bases):
        # TODO: turn off is_available test since this breaks tests
        # assert accel_obj.is_available(), \
        #     f'{accel_obj.__class__.__name__} accelerator fails is_available() test'
        return
    raise AssertionError(f"{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator")
def is_current_accelerator_supported():
    """Return whether the active accelerator is one of ``SUPPORTED_ACCELERATOR_LIST``.

    The list holds accelerator *names* (strings), so the comparison is made on the
    accelerator's device name rather than on the accelerator object itself; comparing
    the object against the list of strings would never match.

    Returns:
        bool: ``True`` if ``get_accelerator().device_name()`` is a supported name.
    """
    return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST
def get_accelerator():
    """Return the process-wide accelerator, selecting and caching it on first use.

    Selection happens in three steps:

    1. If the ``DS_ACCELERATOR`` environment variable is set, its value is used. For
       ``xpu``, ``cpu``, ``npu`` and ``mps`` the package that accelerator needs must be
       importable; any value outside ``SUPPORTED_ACCELERATOR_LIST`` is rejected.
    2. Otherwise the accelerator is auto-detected by probing, in order, for
       ``intel_extension_for_deepspeed`` (xpu), ``intel_extension_for_pytorch`` (cpu),
       ``torch_npu`` (npu) and a working ``torch.mps`` (mps); ``cuda`` is the fallback.
    3. The matching accelerator class is instantiated, validated and stored in the
       module-level ``ds_accelerator``.

    Returns:
        The cached accelerator object.

    Raises:
        ValueError: If ``DS_ACCELERATOR`` is not one of ``SUPPORTED_ACCELERATOR_LIST``, or
            names an accelerator whose required package is unavailable.
        AssertionError: If the created object is rejected by ``_validate_accelerator``.
    """
    global ds_accelerator
    if ds_accelerator is not None:
        return ds_accelerator

    accelerator_name = None
    ds_set_method = None

    # 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
    if "DS_ACCELERATOR" in os.environ:
        accelerator_name = os.environ["DS_ACCELERATOR"]
        if accelerator_name == "xpu":
            try:
                from intel_extension_for_deepspeed import XPU_Accelerator  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError(
                    "XPU_Accelerator requires intel_extension_for_deepspeed, which is not installed on this system."
                ) from e
        elif accelerator_name == "cpu":
            try:
                import intel_extension_for_pytorch  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError(
                    "CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system."
                ) from e
        elif accelerator_name == "npu":
            try:
                import torch_npu  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError(
                    "NPU_Accelerator requires torch_npu, which is not installed on this system.") from e
        elif accelerator_name == "mps":
            try:
                import torch.mps

                # should use torch.mps.is_available() if it exists someday but this is used as proxy
                torch.mps.current_allocated_memory()
            except (RuntimeError, ImportError) as e:
                raise ValueError(
                    "MPS_Accelerator requires torch.mps, which is not installed on this system.") from e
        elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
            # Check the requested *name* directly. Asking the current accelerator here
            # would re-enter get_accelerator() before anything is cached and never return.
            raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
                             f'Value "{accelerator_name}" is not supported')
        ds_set_method = "override"

    # 2. If no override, detect which accelerator to use automatically
    if accelerator_name is None:
        # We need a way to choose among different accelerator types.
        # Currently we detect which accelerator extension is installed
        # in the environment and use it if the installing answer is True.
        # An alternative might be detect whether CUDA device is installed on
        # the system but this comes with two pitfalls:
        # 1. the system may not have torch pre-installed, so
        #    get_accelerator().is_available() may not work.
        # 2. Some scenario like install on login node (without CUDA device)
        #    and run on compute node (with CUDA device) may cause mismatch
        #    between installation time and runtime.
        try:
            from intel_extension_for_deepspeed import XPU_Accelerator  # noqa: F401,F811 # type: ignore
            accelerator_name = "xpu"
        except ImportError:
            pass
        if accelerator_name is None:
            try:
                import intel_extension_for_pytorch  # noqa: F401,F811 # type: ignore
                accelerator_name = "cpu"
            except ImportError:
                pass
        if accelerator_name is None:
            try:
                import torch_npu  # noqa: F401,F811 # type: ignore
                accelerator_name = "npu"
            except ImportError:
                pass
        if accelerator_name is None:
            try:
                import torch.mps

                # should use torch.mps.is_available() if it exists someday but this is used as proxy
                torch.mps.current_allocated_memory()
                accelerator_name = "mps"
            except (RuntimeError, ImportError):
                pass
        if accelerator_name is None:
            accelerator_name = "cuda"

        ds_set_method = "auto detect"

    # 3. Set ds_accelerator accordingly
    if accelerator_name == "cuda":
        from .cuda_accelerator import CUDA_Accelerator
        accelerator = CUDA_Accelerator()
    elif accelerator_name == "cpu":
        from .cpu_accelerator import CPU_Accelerator
        accelerator = CPU_Accelerator()
    elif accelerator_name == "xpu":
        # XPU_Accelerator is already imported in detection stage
        accelerator = XPU_Accelerator()
    elif accelerator_name == "npu":
        from .npu_accelerator import NPU_Accelerator
        accelerator = NPU_Accelerator()
    elif accelerator_name == "mps":
        from .mps_accelerator import MPS_Accelerator
        accelerator = MPS_Accelerator()

    # Validate before caching so a non-conforming object is never handed out by later calls.
    _validate_accelerator(accelerator)
    ds_accelerator = accelerator
    if accel_logger is not None:
        accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
    return ds_accelerator
def set_accelerator(accel_obj):
    """Install ``accel_obj`` as the process-wide accelerator.

    The object is first checked with ``_validate_accelerator`` (which raises
    ``AssertionError`` for non-conforming objects). If a logger is available the
    change is logged, and only then is the module-level ``ds_accelerator`` replaced.
    """
    global ds_accelerator
    _validate_accelerator(accel_obj)
    if accel_logger is not None:
        message = f"Setting ds_accelerator to {accel_obj._name} (model specified)"
        accel_logger.info(message)
    ds_accelerator = accel_obj
# Usage examples for get_accelerator() and set_accelerator(), kept as a bare
# module-level string; nothing in it is executed.
# NOTE(review): the snippets call `logger` without importing it, and the sample
# output labels (e.g. `my_accelerator.name()`) do not match the logged
# expressions (e.g. `my_accelerator._name`) -- treat the output as illustrative.
"""
-----------[code] test_get.py -----------
from deepspeed.accelerator import get_accelerator
my_accelerator = get_accelerator()
logger.info(f'{my_accelerator._name=}')
logger.info(f'{my_accelerator._communication_backend=}')
logger.info(f'{my_accelerator.HalfTensor().device=}')
logger.info(f'{my_accelerator.total_memory()=}')
-----------[code] test_get.py -----------
---[output] python test_get.py---------
my_accelerator.name()='cuda'
my_accelerator.communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_get.py---------
**************************************************************************
-----------[code] test_set.py -----------
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
cu_accel = CUDA_Accelerator()
logger.info(f'{id(cu_accel)=}')
from deepspeed.accelerator import set_accelerator, get_accelerator
set_accelerator(cu_accel)
my_accelerator = get_accelerator()
logger.info(f'{id(my_accelerator)=}')
logger.info(f'{my_accelerator._name=}')
logger.info(f'{my_accelerator._communication_backend=}')
logger.info(f'{my_accelerator.HalfTensor().device=}')
logger.info(f'{my_accelerator.total_memory()=}')
-----------[code] test_set.py -----------
---[output] python test_set.py---------
id(cu_accel)=139648165478304
my_accelerator=<deepspeed.accelerator.cuda_accelerator.CUDA_Accelerator object at 0x7f025f4bffa0>
my_accelerator.name='cuda'
my_accelerator.communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_set.py---------
"""