mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Make op builder detection adapt to accelerator change (#5206)
This is an WIP PR that make op builder detection adapt to accelerator change. This is followup of https://github.com/microsoft/DeepSpeed/issues/5173 Currently, DeepSpeed generate `installed_ops` and `compatible_ops` at setup time. If the system change to a different accelerator at DeepSpeed launch time, these two list would contain incorrect information. This PR intend to solve this problem with more flexity ops detection. * For `installed_ops`, DeepSpeed should disable all installed ops if accelerator detected at setup time is different from launch time. * For `compatible_ops`, DeepSpeed should refresh the list for each launch to avoid impact of accelerator change. In the first step, nv-inference workflow is temporary change to emulate the scenario that the system is setup with CPU_Accelerator, then launch with CUDA_Accelerator. And CPU_Accelerator is modified to make Intel Extension for PyTorch and oneCCL binding for PyTorch not mandatory. Starting from here we can reconstruct installed_ops and compatible_ops to follow the design above. --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
23
.github/workflows/cpu-inference.yml
vendored
23
.github/workflows/cpu-inference.yml
vendored
@ -47,42 +47,26 @@ jobs:
|
||||
- name: Detect instruction sets on instance
|
||||
run: |
|
||||
lscpu
|
||||
pip install cmake
|
||||
git clone https://github.com/intel/intel-extension-for-pytorch
|
||||
cd intel-extension-for-pytorch/tests/cpu/isa
|
||||
cmake .
|
||||
make
|
||||
./cpu_features
|
||||
|
||||
- name: Install numactl
|
||||
run: |
|
||||
sudo apt-get install -y numactl
|
||||
|
||||
- name: Install oneCCL Bindings for PyTorch
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install torch
|
||||
python -m pip install intel_extension_for_pytorch
|
||||
# the curl line is for troubleshooting
|
||||
curl -L https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
|
||||
python -m pip install oneccl_bind_pt --index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
|
||||
pip install py-cpuinfo
|
||||
# check installed version
|
||||
pip list |grep \\\<torch\\\>
|
||||
pip list |grep intel-extension-for-pytorch
|
||||
pip list |grep oneccl-bind-pt
|
||||
|
||||
- name: Install oneCCL
|
||||
run: |
|
||||
pip install cmake
|
||||
git clone https://github.com/oneapi-src/oneCCL
|
||||
cd oneCCL
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ..
|
||||
make
|
||||
make install
|
||||
#source ./_install/env/setvars.sh
|
||||
# test whether oneCCL is correctly installed
|
||||
#mpirun -n 2 ./examples/benchmark/benchmark
|
||||
make -j install
|
||||
|
||||
- name: Install transformers
|
||||
run: |
|
||||
@ -103,7 +87,6 @@ jobs:
|
||||
source oneCCL/build/_install/env/setvars.sh
|
||||
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6
|
||||
# check whether the environment is properly setup
|
||||
python -c "import torch;import intel_extension_for_pytorch as ipex;import oneccl_bindings_for_pytorch;print('done')"
|
||||
python -c "import deepspeed;from deepspeed.accelerator import get_accelerator;print(get_accelerator().device_name());print(get_accelerator().is_available())"
|
||||
|
||||
- name: Unit tests
|
||||
|
4
.github/workflows/cpu-torch-latest.yml
vendored
4
.github/workflows/cpu-torch-latest.yml
vendored
@ -27,6 +27,10 @@ jobs:
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
||||
- name: Install system packages
|
||||
run: |
|
||||
sudo apt-get install -y numactl pdsh
|
||||
|
||||
- name: Install pytorch
|
||||
run: |
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
|
5
.github/workflows/nv-inference.yml
vendored
5
.github/workflows/nv-inference.yml
vendored
@ -46,7 +46,8 @@ jobs:
|
||||
|
||||
- name: Install deepspeed
|
||||
run: |
|
||||
pip install .[dev,1bit,autotuning,inf,triton]
|
||||
DS_ACCELERATOR=cpu pip install .[dev,1bit,autotuning,inf]
|
||||
#pip install .[dev,1bit,autotuning,inf,triton]
|
||||
ds_report
|
||||
|
||||
- name: Python environment
|
||||
@ -60,3 +61,5 @@ jobs:
|
||||
#pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
|
||||
pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="2.1" --cuda_ver="11.8"
|
||||
pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
|
||||
# run ds_report again to check updated op list
|
||||
ds_report
|
||||
|
2
.github/workflows/nv-pre-compile-ops.yml
vendored
2
.github/workflows/nv-pre-compile-ops.yml
vendored
@ -36,7 +36,7 @@ jobs:
|
||||
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
|
||||
- name: Compile DeepSpeed Ops
|
||||
run: |
|
||||
DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
|
||||
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
|
||||
- name: DS Report
|
||||
run: |
|
||||
ds_report
|
||||
|
@ -4,9 +4,14 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
|
||||
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
|
||||
import psutil
|
||||
from .abstract_accelerator import DeepSpeedAccelerator
|
||||
|
||||
try:
|
||||
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
|
||||
oneccl_imported_p = True
|
||||
except ImportError as e:
|
||||
oneccl_imported_p = False
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@ -15,8 +20,17 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
||||
|
||||
def __init__(self):
|
||||
self._name = 'cpu'
|
||||
self._communication_backend_name = 'ccl'
|
||||
self.max_mem = psutil.Process().memory_info().rss
|
||||
if oneccl_imported_p:
|
||||
self._communication_backend_name = 'ccl'
|
||||
else:
|
||||
# fallback to gloo if oneccl_binding_for_pytorch is not installed
|
||||
self._communication_backend_name = 'gloo'
|
||||
try:
|
||||
import psutil
|
||||
mem = psutil.Process().memory_info().rss
|
||||
self.max_mem = mem
|
||||
except ImportError as e:
|
||||
self.max_mem = 0
|
||||
|
||||
def is_synchronized_device(self):
|
||||
return True
|
||||
@ -115,12 +129,14 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
||||
return
|
||||
|
||||
def get_rss(self):
|
||||
import psutil
|
||||
mem = psutil.Process().memory_info().rss
|
||||
if mem > self.max_mem:
|
||||
self.max_mem = mem
|
||||
return mem
|
||||
|
||||
def reset_rss(self):
|
||||
import psutil
|
||||
mem = psutil.Process().memory_info().rss
|
||||
self.max_mem = mem
|
||||
return mem
|
||||
@ -166,9 +182,11 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
||||
return self.max_mem
|
||||
|
||||
def total_memory(self, device_index=None):
|
||||
import psutil
|
||||
return psutil.virtual_memory().total
|
||||
|
||||
def available_memory(self, device_index=None):
|
||||
import psutil
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
# Misc
|
||||
|
@ -73,11 +73,7 @@ def get_accelerator():
|
||||
f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
|
||||
)
|
||||
elif accelerator_name == "cpu":
|
||||
try:
|
||||
import intel_extension_for_pytorch # noqa: F401 # type: ignore
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
f"CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
|
||||
pass
|
||||
elif accelerator_name == "npu":
|
||||
try:
|
||||
import torch_npu # noqa: F401 # type: ignore
|
||||
@ -154,7 +150,23 @@ def get_accelerator():
|
||||
except ImportError as e:
|
||||
pass
|
||||
if accelerator_name is None:
|
||||
accelerator_name = "cuda"
|
||||
# borrow this log from PR#5084
|
||||
try:
|
||||
import torch
|
||||
|
||||
# Determine if we are on a GPU or x86 CPU with torch.
|
||||
if torch.cuda.is_available(): #ignore-cuda
|
||||
accelerator_name = "cuda"
|
||||
else:
|
||||
if accel_logger is not None:
|
||||
accel_logger.warn(
|
||||
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it."
|
||||
)
|
||||
accelerator_name = "cpu"
|
||||
except (RuntimeError, ImportError) as e:
|
||||
# TODO need a more decent way to detect which accelerator to use, consider using nvidia-smi command for detection
|
||||
accelerator_name = "cuda"
|
||||
pass
|
||||
|
||||
ds_set_method = "auto detect"
|
||||
|
||||
|
@ -9,7 +9,7 @@ import deepspeed
|
||||
import subprocess
|
||||
import argparse
|
||||
from .ops.op_builder.all_ops import ALL_OPS
|
||||
from .git_version_info import installed_ops, torch_info
|
||||
from .git_version_info import installed_ops, torch_info, accelerator_name
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
GREEN = '\033[92m'
|
||||
@ -51,7 +51,8 @@ def op_report(verbose=True):
|
||||
for op_name, builder in ALL_OPS.items():
|
||||
dots = "." * (max_dots - len(op_name))
|
||||
is_compatible = OKAY if builder.is_compatible(verbose) else no
|
||||
is_installed = installed if installed_ops.get(op_name, False) else no
|
||||
is_installed = installed if installed_ops.get(op_name,
|
||||
False) and accelerator_name == get_accelerator()._name else no
|
||||
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
|
||||
print(op_name, dots, is_installed, dots2, is_compatible)
|
||||
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
|
||||
|
@ -18,5 +18,14 @@ except ModuleNotFoundError:
|
||||
|
||||
from .ops.op_builder.all_ops import ALL_OPS
|
||||
installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
|
||||
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
|
||||
accelerator_name = ""
|
||||
torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"}
|
||||
|
||||
# compatible_ops list is recreated for each launch
|
||||
from .ops.op_builder.all_ops import ALL_OPS
|
||||
|
||||
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
|
||||
for op_name, builder in ALL_OPS.items():
|
||||
op_compatible = builder.is_compatible()
|
||||
compatible_ops[op_name] = op_compatible
|
||||
compatible_ops["deepspeed_not_implemented"] = False
|
||||
|
@ -7,8 +7,6 @@ from . import adam
|
||||
from . import adagrad
|
||||
from . import lamb
|
||||
from . import lion
|
||||
#from ..git_version_info_installed import installed_ops as __installed_ops__
|
||||
#if __installed_ops__['sparse_attn']:
|
||||
from . import sparse_attention
|
||||
from . import transformer
|
||||
|
||||
|
@ -56,7 +56,8 @@ class NoGatherHandle:
|
||||
self.__param = param
|
||||
|
||||
def wait(self) -> None:
|
||||
get_accelerator().current_stream().synchronize()
|
||||
if not get_accelerator().is_synchronized_device():
|
||||
get_accelerator().current_stream().synchronize()
|
||||
self.__param.ds_status = ZeroParamStatus.AVAILABLE
|
||||
|
||||
|
||||
@ -81,7 +82,8 @@ class NoGatherCoalescedHandle:
|
||||
if self.__complete:
|
||||
return
|
||||
|
||||
get_accelerator().current_stream().synchronize()
|
||||
if not get_accelerator().is_synchronized_device():
|
||||
get_accelerator().current_stream().synchronize()
|
||||
for param in self.__params:
|
||||
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
|
||||
param.ds_status = ZeroParamStatus.AVAILABLE
|
||||
@ -363,7 +365,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||
else:
|
||||
self.dtype = torch.float
|
||||
else:
|
||||
self.dtype = dtype or torch.half
|
||||
self.dtype = dtype or torch.float16 if get_accelerator().is_fp16_supported(
|
||||
) else torch.bfloat16 if get_accelerator().is_bf16_supported else torch.float32
|
||||
|
||||
def patch_init_and_builtins(self):
|
||||
|
||||
|
@ -30,3 +30,4 @@ for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module
|
||||
__op_builders__.append(builder)
|
||||
|
||||
ALL_OPS = {op.name: op for op in __op_builders__ if op is not None}
|
||||
accelerator_name = get_accelerator()._name
|
||||
|
@ -464,8 +464,9 @@ class OpBuilder(ABC):
|
||||
if self.name in __class__._loaded_ops:
|
||||
return __class__._loaded_ops[self.name]
|
||||
|
||||
from deepspeed.git_version_info import installed_ops, torch_info
|
||||
if installed_ops.get(self.name, False):
|
||||
from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name:
|
||||
# Ensure the op we're about to load was compiled with the same
|
||||
# torch/cuda versions we are currently using at runtime.
|
||||
self.validate_torch_version(torch_info)
|
||||
|
@ -74,8 +74,9 @@ class SYCLOpBuilder(OpBuilder):
|
||||
]
|
||||
|
||||
def load(self, verbose=True):
|
||||
from deepspeed.git_version_info import installed_ops, torch_info # noqa: F401
|
||||
if installed_ops.get(self.name, False):
|
||||
from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name # noqa: F401
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name:
|
||||
return importlib.import_module(self.absolute_name())
|
||||
else:
|
||||
return self.jit_load(verbose)
|
||||
|
8
setup.py
8
setup.py
@ -35,7 +35,7 @@ except ImportError:
|
||||
'Please visit https://pytorch.org/ to see how to properly install torch on your system.')
|
||||
|
||||
from op_builder import get_default_compute_capabilities, OpBuilder
|
||||
from op_builder.all_ops import ALL_OPS
|
||||
from op_builder.all_ops import ALL_OPS, accelerator_name
|
||||
from op_builder.builder import installed_cuda_version
|
||||
|
||||
# Fetch rocm state.
|
||||
@ -168,12 +168,9 @@ def op_enabled(op_name):
|
||||
return int(get_env_if_set(env_var, BUILD_OP_DEFAULT))
|
||||
|
||||
|
||||
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
|
||||
install_ops = dict.fromkeys(ALL_OPS.keys(), False)
|
||||
for op_name, builder in ALL_OPS.items():
|
||||
op_compatible = builder.is_compatible()
|
||||
compatible_ops[op_name] = op_compatible
|
||||
compatible_ops["deepspeed_not_implemented"] = False
|
||||
|
||||
# If op is requested but not available, throw an error.
|
||||
if op_enabled(op_name) and not op_compatible:
|
||||
@ -280,11 +277,10 @@ with open('deepspeed/git_version_info_installed.py', 'w') as fd:
|
||||
fd.write(f"git_hash='{git_hash}'\n")
|
||||
fd.write(f"git_branch='{git_branch}'\n")
|
||||
fd.write(f"installed_ops={install_ops}\n")
|
||||
fd.write(f"compatible_ops={compatible_ops}\n")
|
||||
fd.write(f"accelerator_name='{accelerator_name}'\n")
|
||||
fd.write(f"torch_info={torch_info}\n")
|
||||
|
||||
print(f'install_requires={install_requires}')
|
||||
print(f'compatible_ops={compatible_ops}')
|
||||
print(f'ext_modules={ext_modules}')
|
||||
|
||||
# Parse README.md to make long_description for PyPI page.
|
||||
|
@ -14,6 +14,7 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
|
||||
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
|
||||
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
||||
|
||||
from unit.common import preferred_dtype
|
||||
from unit.simple_model import *
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -163,13 +164,15 @@ def checkpoint_correctness_verification(config_dict,
|
||||
tmpdir,
|
||||
load_optimizer_states=False,
|
||||
load_lr_scheduler_states=False,
|
||||
fp16=True,
|
||||
train_batch=False,
|
||||
base_optimizers=[None, None],
|
||||
empty_tag=False,
|
||||
seq_dataloader=False,
|
||||
load_module_only=False):
|
||||
dtype = torch.half if fp16 else torch.float32
|
||||
load_module_only=False,
|
||||
dtype=None):
|
||||
if dtype == None:
|
||||
dtype = preferred_dtype()
|
||||
|
||||
ds_model = create_deepspeed_model(config_dict=config_dict, model=models[0], base_optimizer=base_optimizers[0])
|
||||
|
||||
if seq_dataloader:
|
||||
@ -241,7 +244,7 @@ def checkpoint_correctness_verification(config_dict,
|
||||
load_module_only=load_module_only)
|
||||
|
||||
if load_optimizer_states:
|
||||
compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)
|
||||
compare_optimizer_states(trained_model, loaded_model, hidden_dim, dtype == torch.float16)
|
||||
|
||||
if load_lr_scheduler_states:
|
||||
compare_lr_scheduler_states(trained_model, loaded_model)
|
||||
|
@ -38,8 +38,8 @@ class TestLatestCheckpoint(DistributedTest):
|
||||
tmpdir=tmpdir,
|
||||
load_optimizer_states=True,
|
||||
load_lr_scheduler_states=False,
|
||||
fp16=False,
|
||||
empty_tag=True)
|
||||
empty_tag=True,
|
||||
dtype=torch.float)
|
||||
|
||||
def test_missing_latest(self, tmpdir):
|
||||
config_dict = {
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
import deepspeed
|
||||
from deepspeed.ops.op_builder import CPUAdamBuilder
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from unit.common import DistributedTest
|
||||
from unit.simple_model import *
|
||||
@ -22,6 +23,8 @@ class TestLRSchedulerCheckpoint(DistributedTest):
|
||||
def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload):
|
||||
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
|
||||
pytest.skip("cpu-adam is not compatible")
|
||||
if get_accelerator().device_name() == 'cpu':
|
||||
pytest.skip("CPU accelerator does not support this test.")
|
||||
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
@ -35,9 +38,6 @@ class TestLRSchedulerCheckpoint(DistributedTest):
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
"cpu_offload": use_cpu_offload
|
||||
@ -51,6 +51,10 @@ class TestLRSchedulerCheckpoint(DistributedTest):
|
||||
}
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif get_accelerator().is_fp16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
if zero_stage == 3:
|
||||
@ -71,6 +75,8 @@ class TestLRSchedulerCheckpoint(DistributedTest):
|
||||
def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload):
|
||||
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
|
||||
pytest.skip("cpu-adam is not compatible")
|
||||
if get_accelerator().device_name() == 'cpu':
|
||||
pytest.skip("CPU accelerator does not support this test.")
|
||||
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
@ -81,9 +87,6 @@ class TestLRSchedulerCheckpoint(DistributedTest):
|
||||
"lr": 1e-5
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
"cpu_offload": use_cpu_offload
|
||||
@ -97,6 +100,10 @@ class TestLRSchedulerCheckpoint(DistributedTest):
|
||||
}
|
||||
},
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif get_accelerator().is_fp16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
if zero_stage == 3:
|
||||
|
@ -33,10 +33,10 @@ class TestMoECheckpoint(DistributedTest):
|
||||
tmpdir=tmpdir,
|
||||
load_optimizer_states=True,
|
||||
load_lr_scheduler_states=False,
|
||||
fp16=config_dict["fp16"]["enabled"],
|
||||
empty_tag=True,
|
||||
base_optimizers=optimizers,
|
||||
seq_dataloader=True)
|
||||
seq_dataloader=True,
|
||||
dtype=torch.float16)
|
||||
|
||||
@pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)])
|
||||
def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states):
|
||||
@ -77,7 +77,7 @@ class TestMoECheckpoint(DistributedTest):
|
||||
tmpdir=tmpdir,
|
||||
load_optimizer_states=load_optim_states,
|
||||
load_lr_scheduler_states=False,
|
||||
fp16=config_dict["fp16"]["enabled"],
|
||||
empty_tag=True,
|
||||
base_optimizers=optimizers,
|
||||
seq_dataloader=True)
|
||||
seq_dataloader=True,
|
||||
dtype=torch.float16)
|
||||
|
@ -19,6 +19,8 @@ class TestOtherOptimizerCheckpoint(DistributedTest):
|
||||
|
||||
@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
|
||||
def test_checkpoint_unfused_optimizer(self, tmpdir):
|
||||
#if not get_accelerator().is_fp16_supported():
|
||||
# pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -29,9 +31,6 @@ class TestOtherOptimizerCheckpoint(DistributedTest):
|
||||
}
|
||||
},
|
||||
"gradient_clipping": 1.0,
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "OneCycle",
|
||||
"params": {
|
||||
@ -49,6 +48,10 @@ class TestOtherOptimizerCheckpoint(DistributedTest):
|
||||
}
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif get_accelerator().is_fp16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
@ -69,6 +72,8 @@ class TestOtherOptimizerCheckpoint(DistributedTest):
|
||||
load_optimizer_states=False)
|
||||
|
||||
def test_checkpoint_fused_optimizer(self, tmpdir):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -81,10 +86,11 @@ class TestOtherOptimizerCheckpoint(DistributedTest):
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
|
||||
args = args_from_dict(tmpdir, config_dict)
|
||||
hidden_dim = 10
|
||||
@ -129,4 +135,4 @@ class TestOtherOptimizerCheckpoint(DistributedTest):
|
||||
models=models,
|
||||
hidden_dim=hidden_dim,
|
||||
tmpdir=tmpdir,
|
||||
fp16=False)
|
||||
dtype=torch.float32)
|
||||
|
@ -58,10 +58,10 @@ class TestPipelineCheckpoint(DistributedTest):
|
||||
models=models,
|
||||
hidden_dim=models[0].hidden_dim,
|
||||
tmpdir=tmpdir,
|
||||
fp16=config_dict['fp16']['enabled'],
|
||||
load_optimizer_states=True,
|
||||
load_lr_scheduler_states=True,
|
||||
train_batch=True)
|
||||
train_batch=True,
|
||||
dtype=torch.float16 if zero_stage > 0 else torch.float32)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"base_topo,test_topo",
|
||||
|
@ -28,15 +28,15 @@ class TestZeROCheckpoint(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
"pipeline_loading_checkpoint": True,
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
with deepspeed.zero.Init():
|
||||
@ -64,16 +64,16 @@ class TestZeROCheckpoint(DistributedTest):
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"wall_clock_breakdown": True,
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
"cpu_offload": use_cpu_offload
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
if zero_stage == 3:
|
||||
@ -104,14 +104,15 @@ class TestZeROCheckpoint(DistributedTest):
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
"cpu_offload": use_cpu_offload
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
if zero_stage == 3:
|
||||
@ -134,11 +135,11 @@ class TestZeROCheckpoint(DistributedTest):
|
||||
"stage": zero_stage
|
||||
},
|
||||
"zero_allow_untested_optimizer": True,
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
|
||||
optimizers = [HybridStateOptimizer(model.parameters()) for model in models]
|
||||
@ -152,19 +153,21 @@ class TestZeROCheckpoint(DistributedTest):
|
||||
|
||||
@pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
|
||||
def test_load_module_only(self, tmpdir, zero_stage):
|
||||
if zero_stage == 0 and get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU Accelerator does not support this test")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
if zero_stage == 3:
|
||||
@ -185,15 +188,15 @@ class ws4_model_checkpoint(DistributedFixture):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"elastic_checkpoint": elastic_save
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
ds_config["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
model = SimpleModel(hidden_dim)
|
||||
|
||||
@ -221,15 +224,15 @@ class TestZeROElasticCheckpoint(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"elastic_checkpoint": elastic_save
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
ds_config["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
# torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to
|
||||
@ -274,15 +277,15 @@ class TestZeROElasticCheckpoint(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"elastic_checkpoint": elastic_load
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
ds_config["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
model = SimpleModel(hidden_dim)
|
||||
|
||||
@ -305,14 +308,14 @@ class TestZeROSaveLoadEdgeCase(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
model = SimpleModel(hidden_dim)
|
||||
|
||||
@ -325,30 +328,27 @@ class TestZeROSaveLoadEdgeCase(DistributedTest):
|
||||
|
||||
@pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
|
||||
def test_load_immediate_save(self, tmpdir, zero_stage):
|
||||
if zero_stage == 0 and get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU Accelerator does not support this test")
|
||||
config_dict = {
|
||||
"train_batch_size": 4,
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
model = SimpleModel(hidden_dim)
|
||||
|
||||
# 1. pretrain a model and save it
|
||||
dtype = torch.half
|
||||
ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)
|
||||
data_loader = random_dataloader(model=ds_model,
|
||||
total_samples=1,
|
||||
hidden_dim=hidden_dim,
|
||||
device=ds_model.device,
|
||||
dtype=dtype)
|
||||
data_loader = random_dataloader(model=ds_model, total_samples=1, hidden_dim=hidden_dim, device=ds_model.device)
|
||||
for _, batch in enumerate(data_loader):
|
||||
loss = ds_model(batch[0], batch[1])
|
||||
ds_model.backward(loss)
|
||||
@ -371,10 +371,6 @@ class TestZeROSaveLoadEdgeCase(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
"stage3_gather_fp16_weights_on_model_save": True,
|
||||
@ -383,6 +379,10 @@ class TestZeROSaveLoadEdgeCase(DistributedTest):
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"train_batch_size": 4,
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
model = SimpleModel(hidden_dim)
|
||||
|
||||
@ -391,11 +391,7 @@ class TestZeROSaveLoadEdgeCase(DistributedTest):
|
||||
# So we config grad_accum=2 and step only once and save_16bit_model
|
||||
ds_model = create_deepspeed_model(config_dict=config_dict, model=model, base_optimizer=None)
|
||||
|
||||
data_loader = random_dataloader(model=ds_model,
|
||||
total_samples=2,
|
||||
hidden_dim=hidden_dim,
|
||||
device=ds_model.device,
|
||||
dtype=torch.half)
|
||||
data_loader = random_dataloader(model=ds_model, total_samples=2, hidden_dim=hidden_dim, device=ds_model.device)
|
||||
|
||||
batch = next(iter(data_loader))
|
||||
loss = ds_model(batch[0], batch[1])
|
||||
@ -429,15 +425,15 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"wall_clock_breakdown": True,
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
with deepspeed.zero.Init(enabled=zero_stage == 3):
|
||||
@ -460,13 +456,14 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
with deepspeed.zero.Init(enabled=zero_stage == 3):
|
||||
@ -481,14 +478,14 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
with deepspeed.zero.Init(enabled=zero_stage == 3):
|
||||
@ -504,14 +501,14 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
model = SimpleFrozenModel(hidden_dim, empty_grad=False)
|
||||
@ -552,14 +549,14 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
|
||||
"optimizer": {
|
||||
"type": 'Adam'
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"initial_scale_power": 8
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": zero_stage,
|
||||
}
|
||||
}
|
||||
if get_accelerator().is_fp16_supported():
|
||||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
config_dict["bf16"] = {"enabled": True}
|
||||
hidden_dim = 10
|
||||
|
||||
model = SimpleFrozenModel(hidden_dim, empty_grad=False)
|
||||
|
@ -441,3 +441,13 @@ class DistributedTest(DistributedExec):
|
||||
def get_test_path(filename):
|
||||
curr_path = Path(__file__).parent
|
||||
return str(curr_path.joinpath(filename))
|
||||
|
||||
|
||||
# fp16 > bf16 > fp32
|
||||
def preferred_dtype():
|
||||
if get_accelerator().is_fp16_supported():
|
||||
return torch.float16
|
||||
elif get_accelerator().is_bf16_supported():
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
|
@ -7,8 +7,9 @@
|
||||
|
||||
import os
|
||||
import torch
|
||||
import pytest
|
||||
from unit.common import DistributedTest
|
||||
from deepspeed.ops.op_builder import InferenceBuilder
|
||||
import deepspeed
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
|
||||
@ -18,7 +19,11 @@ class TestDequantization(DistributedTest):
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
self.device = torch.device(get_accelerator().device_name(local_rank))
|
||||
|
||||
self.dequantize_func = InferenceBuilder().load().dequantize_fp16
|
||||
from deepspeed.ops.op_builder import InferenceBuilder
|
||||
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
|
||||
pytest.skip("InferenceBuilder is not implemented")
|
||||
else:
|
||||
self.dequantize_func = InferenceBuilder().load().dequantize_fp16
|
||||
|
||||
def run_dequantize_test(self, M, N, num_groups):
|
||||
weight = torch.randint(-255, 255, (M, N)).to(dtype=torch.int8, device=self.device)
|
||||
|
@ -9,7 +9,7 @@ from unit.common import DistributedTest
|
||||
from deepspeed.git_version_info import version as ds_version
|
||||
import os
|
||||
from unit.simple_model import SimpleModel
|
||||
from deepspeed.ops.op_builder import FusedAdamBuilder
|
||||
from deepspeed.ops.op_builder import FusedAdamBuilder, FusedLambBuilder
|
||||
|
||||
if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
|
||||
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
|
||||
@ -183,6 +183,8 @@ class TestNonElasticBatchParamsWithOverride(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self):
|
||||
if not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME]:
|
||||
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
|
@ -43,7 +43,9 @@ def cmd(user_script_fp, prompt, multi_node):
|
||||
'''I'm going to tell them "DeepSpeed is the best"'''
|
||||
])
|
||||
@pytest.mark.parametrize("multi_node", [True, False])
|
||||
def test_user_args(cmd):
|
||||
def test_user_args(cmd, multi_node):
|
||||
if multi_node and get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet")
|
||||
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
out, err = p.communicate()
|
||||
assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"
|
||||
|
@ -4,6 +4,7 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
from .common import preferred_dtype
|
||||
|
||||
|
||||
class MultiOutputModel(torch.nn.Module):
|
||||
@ -28,8 +29,11 @@ def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, ta
|
||||
batch_size = model.train_micro_batch_size_per_gpu()
|
||||
|
||||
train_data = [
|
||||
torch.full(size=(total_samples, hidden_dim), fill_value=x, device=device, dtype=torch.half, requires_grad=True)
|
||||
for x in inputs
|
||||
torch.full(size=(total_samples, hidden_dim),
|
||||
fill_value=x,
|
||||
device=device,
|
||||
dtype=preferred_dtype(),
|
||||
requires_grad=True) for x in inputs
|
||||
]
|
||||
|
||||
train_label = [torch.empty(total_samples, device=device, dtype=torch.long).fill_(y) for y in targets]
|
||||
|
@ -16,10 +16,6 @@ from unit.modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderP
|
||||
from unit.modelingpreln import BertEncoder as BertEncoderPreln
|
||||
from unit.common import DistributedTest, is_rocm_pytorch
|
||||
|
||||
#if not deepspeed.ops.__installed_ops__['transformer']:
|
||||
#pytest.skip(
|
||||
# "transformer kernels are temporarily disabled because of unexplained failures",
|
||||
# allow_module_level=True)
|
||||
if torch.half not in get_accelerator().supported_dtypes():
|
||||
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
|
||||
|
||||
|
@ -62,6 +62,8 @@ def _match_outputs(ref, tgt):
|
||||
|
||||
|
||||
def _test_activation_checkpoint(module, *inputs):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet")
|
||||
# Move to device
|
||||
module.to(get_accelerator().device_name())
|
||||
|
||||
@ -82,6 +84,8 @@ def _test_activation_checkpoint(module, *inputs):
|
||||
|
||||
|
||||
def _test_activation_checkpoint_ordering(module, expected_ordering, *inputs):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet")
|
||||
# Move to device
|
||||
module.to(get_accelerator().device_name())
|
||||
|
||||
|
@ -7,9 +7,11 @@ unit tests for coalesced collectives
|
||||
"""
|
||||
|
||||
import torch
|
||||
import deepspeed
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
import pytest
|
||||
|
||||
from unit.common import DistributedTest
|
||||
|
||||
@ -68,6 +70,9 @@ class TestAllToAllQuantReduceFallback(DistributedTest):
|
||||
def test_1d_tensor(self):
|
||||
# case 1: 1D tensor
|
||||
input = torch.zeros((10, ), dtype=torch.half, device=get_accelerator().current_device_name())
|
||||
from deepspeed.ops.op_builder import QuantizerBuilder
|
||||
if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
|
||||
pytest.skip("QuantizerBuilder is not implemented")
|
||||
output = all_to_all_quant_reduce([input], {})[0]
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
@ -80,6 +85,9 @@ class TestAllToAllQuantReduceFallback(DistributedTest):
|
||||
def test_non_divisible(self):
|
||||
# case 2: tensor size not divisible by global_world_size
|
||||
input = torch.zeros((7, 7), dtype=torch.half, device=get_accelerator().current_device_name())
|
||||
from deepspeed.ops.op_builder import QuantizerBuilder
|
||||
if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
|
||||
pytest.skip("QuantizerBuilder is not implemented")
|
||||
output = all_to_all_quant_reduce([input], {})[0]
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
|
@ -72,6 +72,8 @@ class TestCustomMethod(DistributedTest):
|
||||
|
||||
@pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
|
||||
def test_custom_function(self, base_config):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet.")
|
||||
test_value = 10
|
||||
|
||||
engine = self._init_engine(base_config, test_value)
|
||||
|
@ -8,6 +8,7 @@ import torch
|
||||
|
||||
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
|
||||
from deepspeed.runtime.utils import required_torch_version
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from unit.runtime.compile.util import compare_loss
|
||||
from unit.common import DistributedTest
|
||||
@ -29,6 +30,8 @@ class TestZeRO(DistributedTest):
|
||||
pytest.skip(
|
||||
" DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
|
||||
)
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU does not support this test yet")
|
||||
|
||||
if offload_device == OffloadDeviceEnum.nvme:
|
||||
if zero_stage != 3:
|
||||
|
@ -74,12 +74,16 @@ class TestConfigLoad(DistributedTest):
|
||||
|
||||
@pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
|
||||
def test_compile(self, base_config):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet.")
|
||||
engine = self._init_engine(base_config)
|
||||
self._run_model(engine)
|
||||
assert engine.is_compiled
|
||||
|
||||
@pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
|
||||
def test_custom_backend(self, base_config):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet.")
|
||||
global custom_backend_called
|
||||
custom_backend_called = False
|
||||
|
||||
@ -89,12 +93,16 @@ class TestConfigLoad(DistributedTest):
|
||||
assert custom_backend_called
|
||||
|
||||
def test_compile_disabled(self, base_config):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet.")
|
||||
base_config["compile"]["enabled"] = False
|
||||
engine = self._init_engine(base_config)
|
||||
self._run_model(engine)
|
||||
|
||||
@pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
|
||||
def test_compile_kwargs(self, base_config):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet.")
|
||||
base_config["compile"]["kwargs"] = {"mode": "default"}
|
||||
engine = self._init_engine(base_config)
|
||||
self._run_model(engine)
|
||||
@ -102,6 +110,8 @@ class TestConfigLoad(DistributedTest):
|
||||
|
||||
@pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
|
||||
def test_set_compile_kwargs(self, base_config):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet.")
|
||||
engine = self._init_engine(base_config)
|
||||
engine.set_torch_compile_kwargs({"mode": "default"})
|
||||
self._run_model(engine)
|
||||
@ -109,6 +119,8 @@ class TestConfigLoad(DistributedTest):
|
||||
|
||||
@pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
|
||||
def test_set_compiler_fn(self, base_config):
|
||||
if get_accelerator().device_name() == "cpu":
|
||||
pytest.skip("CPU accelerator does not support this test yet.")
|
||||
global custom_compler_fn_called
|
||||
custom_compler_fn_called = False
|
||||
|
||||
|
@ -39,6 +39,9 @@ class TestOneBitAdamBasic(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self, dtype):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -80,6 +83,8 @@ class TestOneBitAdamExpAvgMask(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -144,6 +149,8 @@ class TestOneBitAdamCheckpointing(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self, tmpdir):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -293,6 +300,8 @@ class TestOneBitAdamCheckpointing(DistributedTest):
|
||||
assert optimizer_3.optimizer.adam_freeze_key is False
|
||||
|
||||
def test_overflow(self, tmpdir):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -343,6 +352,8 @@ class TestOneBitAdamFP16Pipeline(DistributedTest):
|
||||
world_size = 4
|
||||
|
||||
def test(self, topo_config):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 4,
|
||||
"grandient_accumulation_steps": 1,
|
||||
@ -388,6 +399,8 @@ class TestZeroOneAdamBasic(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self, dtype):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -432,6 +445,8 @@ class TestZeroOneAdamExpAvgMask(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -499,6 +514,8 @@ class TestZeroOneAdamCheckpointing(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self, tmpdir):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -647,6 +664,8 @@ class TestZeroOneAdamCheckpointing(DistributedTest):
|
||||
assert "server_error" not in v, f"Incorrect server error"
|
||||
|
||||
def test_overflow(self, tmpdir):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -700,6 +719,8 @@ class TestZeroOneAdamFP16Pipeline(DistributedTest):
|
||||
world_size = 4
|
||||
|
||||
def test(self, topo_config):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 4,
|
||||
"grandient_accumulation_steps": 1,
|
||||
@ -748,6 +769,8 @@ class TestOneBitLambBasic(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self, dtype):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -795,6 +818,8 @@ class TestOneBitLampExpAvgMask(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -864,6 +889,8 @@ class TestOneBitLambCheckpointing(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self, tmpdir):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -1030,6 +1057,8 @@ class TestOneBitLambCheckpointing(DistributedTest):
|
||||
assert optimizer_3.optimizer.lamb_freeze_key is False
|
||||
|
||||
def test_overflow(self, tmpdir):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
"steps_per_print": 1,
|
||||
@ -1086,6 +1115,8 @@ class TestOneBitLambFP16Pipeline(DistributedTest):
|
||||
world_size = 4
|
||||
|
||||
def test(self, topo_config):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 4,
|
||||
"grandient_accumulation_steps": 1,
|
||||
@ -1131,6 +1162,8 @@ class TestCompressedAllReduceBasic(DistributedTest):
|
||||
world_size = 2
|
||||
|
||||
def test(self, tmpdir):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
from deepspeed.runtime.comm.nccl import NcclBackend
|
||||
|
||||
size = dist.get_world_size()
|
||||
|
@ -12,6 +12,7 @@ from deepspeed.ops.op_builder import CPUAdamBuilder
|
||||
from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader
|
||||
from unit.util import bf16_required_version_check
|
||||
from deepspeed import comm as dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
|
||||
class TestAdamBF16ZeroOneCycleCompatibility(DistributedTest):
|
||||
@ -299,6 +300,10 @@ class TestZeroDtypeCocktail(DistributedTest):
|
||||
" DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
|
||||
)
|
||||
|
||||
if comp_type == torch.float16 or comm_type == torch.float16:
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
|
||||
type_str = {torch.float16: "fp16", torch.bfloat16: "bfp16"}
|
||||
|
||||
config_dict = {
|
||||
|
@ -5,6 +5,8 @@
|
||||
|
||||
import torch
|
||||
import deepspeed
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unit.common import DistributedTest
|
||||
from unit.simple_model import SimpleModel
|
||||
@ -22,6 +24,9 @@ class TestFused(DistributedTest):
|
||||
world_size = 1
|
||||
|
||||
def test_no_overflow(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
|
||||
config_dict = {
|
||||
"train_batch_size": 1,
|
||||
"steps_per_print": 1,
|
||||
@ -57,6 +62,8 @@ class TestFused(DistributedTest):
|
||||
expected_loss_scale *= 2
|
||||
|
||||
def test_all_overflow(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 1,
|
||||
"steps_per_print": 1,
|
||||
@ -90,6 +97,8 @@ class TestFused(DistributedTest):
|
||||
assert optim.cur_iter == (i + 1)
|
||||
|
||||
def test_some_overflow(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 1,
|
||||
"steps_per_print": 1,
|
||||
@ -147,6 +156,8 @@ class TestUnfused(DistributedTest):
|
||||
world_size = 1
|
||||
|
||||
def test_no_overflow(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 1,
|
||||
"steps_per_print": 1,
|
||||
@ -181,6 +192,8 @@ class TestUnfused(DistributedTest):
|
||||
expected_loss_scale *= 2
|
||||
|
||||
def test_all_overflow(self):
|
||||
if not get_accelerator().is_fp16_supported():
|
||||
pytest.skip("fp16 is not supported")
|
||||
config_dict = {
|
||||
"train_batch_size": 1,
|
||||
"steps_per_print": 1,
|
||||
@ -217,6 +230,8 @@ class TestUnfused(DistributedTest):
|
||||
assert optim.cur_iter == (i + 1)
|
    def test_some_overflow(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 1,
            "steps_per_print": 1,

@ -26,6 +26,8 @@ class TestLambFP32GradClip(DistributedTest):
    world_size = 2

    def test(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -56,6 +58,8 @@ class TestLambFP16(DistributedTest):
    world_size = 2

    def test__basic(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -81,6 +85,8 @@ class TestLambFP16(DistributedTest):
            model.step()

    def test_empty_grad(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -143,6 +149,8 @@ class TestAdamwFP16Basic(DistributedTest):
    world_size = 1

    def test(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {"train_batch_size": 1, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 10

@ -160,6 +168,8 @@ class TestFP16OptimizerForMoE(DistributedTest):
    world_size = 2

    def test_unfused_gradnorm(self, monkeypatch):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

@ -188,6 +198,8 @@ class TestFP16OptimizerForMoE(DistributedTest):
            engine.step()

    def test_fused_gradnorm(self, monkeypatch):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

@ -218,6 +230,8 @@ class TestFP16OptimizerForMoE(DistributedTest):

    @pytest.mark.parametrize("fused_lamb_legacy", [(False), (True)])
    def test_lamb_gradnorm(self, monkeypatch, fused_lamb_legacy: bool):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        if not required_torch_version(min_version=1.8):
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

@ -262,6 +276,8 @@ class TestAdamwFP16EmptyGrad(DistributedTest):
    world_size = 1

    def test(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {"train_batch_size": 1, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 10

@ -281,6 +297,8 @@ class TestAdamFP16ZeroOneCycleCompatibility(DistributedTest):
    world_size = 1

    def test(self, zero_stage, use_cpu_offload):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
            pytest.skip("cpu-adam is not compatible")

@ -332,6 +350,8 @@ class TestZeroStaticScale(DistributedTest):
    world_size = 1

    def test(self, zero_stage, use_cpu_offload, hidden_dim=4):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
            pytest.skip("cpu-adam is not compatible")

@ -375,6 +395,8 @@ class TestZeroAllowUntestedOptimizer(DistributedTest):
    world_size = 1

    def test(self, zero_stage, use_cpu_offload):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
            pytest.skip("cpu-adam is not compatible")

@ -408,6 +430,8 @@ class TestZeroEmptyPartition(DistributedTest):
    world_size = 3

    def test(self, zero_stage, use_cpu_offload):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
            pytest.skip("cpu-adam is not compatible")

@ -454,6 +478,8 @@ class TestAmp(DistributedTest):
    world_size = 2

    def test_adam_basic(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {"train_batch_size": 2, "steps_per_print": 1, "amp": {"enabled": True}}
        hidden_dim = 10

@ -467,6 +493,8 @@ class TestAmp(DistributedTest):
            model.step()

    def test_lamb_basic(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -492,6 +520,8 @@ class TestAmp(DistributedTest):
            model.step()

    def test_adam_O2(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -518,6 +548,8 @@ class TestAmp(DistributedTest):
            model.step()

    def test_adam_O2_empty_grad(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -550,6 +582,8 @@ class TestZeroSupportedClientOptimizer(DistributedTest):
    world_size = 1

    def test(self, zero_stage, optimizer_constructor):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -571,6 +605,8 @@ class TestZero2ReduceScatterOff(DistributedTest):
    world_size = 2

    def test(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -610,6 +646,8 @@ class TestFP16AdamTypes(DistributedTest):
    world_size = 1

    def test(self, adam_type, torch_impl):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 1,
            "steps_per_print": 1,

@ -642,6 +680,8 @@ class TestZero3LazyScatter(DistributedTest):
    world_size = 1

    def test(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 1,
            "steps_per_print": 1,

@ -677,6 +717,8 @@ class TestZeroEmptyGrad(DistributedTest):
    world_size = 1

    def test(self, stage):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        config_dict = {
            "train_batch_size": 1,
            "steps_per_print": 1,
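Every hunk in the file above applies the same two-part change: tests that strictly need fp16 skip on accelerators that cannot run it, instead of assuming CUDA fp16. A minimal sketch of that pattern as a standalone pytest test; the helper name and config values below are illustrative placeholders, not part of the diff:

import pytest
from deepspeed.accelerator import get_accelerator


def build_precision_section(config):
    # Illustrative helper: pick fp16 if the active accelerator supports it,
    # otherwise fall back to bf16, mirroring the pattern used in these hunks.
    if get_accelerator().is_fp16_supported():
        config["fp16"] = {"enabled": True}
    elif get_accelerator().is_bf16_supported():
        config["bf16"] = {"enabled": True}
    return config


def test_fp16_guarded_example():
    # fp16-only tests skip instead of erroring on, e.g., the CPU accelerator.
    if not get_accelerator().is_fp16_supported():
        pytest.skip("fp16 is not supported")
    config = build_precision_section({"train_batch_size": 1, "steps_per_print": 1})
    assert "fp16" in config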
@ -7,6 +7,7 @@ import torch
import os
import deepspeed
from deepspeed.accelerator import get_accelerator
import pytest
from unit.common import DistributedTest
from unit.simple_model import Curriculum_SimpleModel, SimpleModel, random_dataloader, random_dataset

@ -53,6 +54,8 @@ class TestDataEfficiency(DistributedTest):
    world_size = 2

    def test_curriculum_learning(self):
        if get_accelerator().device_name() == "cpu":
            pytest.skip("CPU accelerator does not support this test yet")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -64,11 +67,6 @@ class TestDataEfficiency(DistributedTest):
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 16
            },
            "data_efficiency": {
                "enabled": True,
                "seed": 1234,

@ -98,6 +96,10 @@ class TestDataEfficiency(DistributedTest):
                }
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        def data_post_process(data, data_sampler_state_dict):
            assert 'dummy_metric' in data_sampler_state_dict['current_difficulties']

@ -105,7 +107,7 @@ class TestDataEfficiency(DistributedTest):

        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        dataset = random_dataset(20, hidden_dim, torch.device('cpu'), dtype=torch.half)
        dataset = random_dataset(20, hidden_dim, torch.device('cpu'))
        model, _, data_loader, _ = deepspeed.initialize(config=config_dict,
                                                        model=model,
                                                        training_data=dataset,

@ -128,6 +130,8 @@ class TestLegacyCurriculumScheduler(DistributedTest):
    world_size = 2

    def test_fixed_discrete(self):
        if get_accelerator().device_name() == "cpu":
            pytest.skip("CPU accelerator does not support this test yet")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -139,11 +143,6 @@ class TestLegacyCurriculumScheduler(DistributedTest):
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 16
            },
            "curriculum_learning": {
                "enabled": True,
                "curriculum_type": "seqlen",

@ -156,6 +155,10 @@ class TestLegacyCurriculumScheduler(DistributedTest):
                }
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10
        ground_truths = {1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4}

@ -172,6 +175,8 @@ class TestLegacyCurriculumScheduler(DistributedTest):
            assert seqlen == true_seqlen, f"Incorrect curriculum schedule"

    def test_fixed_linear(self):
        if get_accelerator().device_name() == "cpu":
            pytest.skip("CPU accelerator does not support this test yet")
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,

@ -183,11 +188,6 @@ class TestLegacyCurriculumScheduler(DistributedTest):
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 16
            },
            "curriculum_learning": {
                "enabled": True,
                "curriculum_type": "seqlen",

@ -200,6 +200,10 @@ class TestLegacyCurriculumScheduler(DistributedTest):
                }
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10
        ground_truths = {1: 2, 2: 4, 3: 4, 4: 6, 5: 6, 6: 8, 7: 8, 8: 10, 9: 10, 10: 10}
@ -47,9 +47,6 @@ def base_config():
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    return config_dict

@ -163,11 +160,19 @@ class TestConfigLoad(DistributedTest):
    world_size = 1

    def test_dict(self, base_config):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(config=base_config, model=model, model_parameters=model.parameters())

    def test_json(self, base_config, tmpdir):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        config_path = os.path.join(tmpdir, "config.json")
        with open(config_path, 'w') as fp:
            json.dump(base_config, fp)

@ -176,6 +181,10 @@ class TestConfigLoad(DistributedTest):
        model, _, _, _ = deepspeed.initialize(config=config_path, model=model, model_parameters=model.parameters())

    def test_hjson(self, base_config, tmpdir):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        config_path = os.path.join(tmpdir, "config.json")
        with open(config_path, 'w') as fp:
            hjson.dump(base_config, fp)

@ -188,6 +197,10 @@ class TestDeprecatedDeepScaleConfig(DistributedTest):
    world_size = 1

    def test(self, base_config, tmpdir):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        config_path = create_config_from_dict(tmpdir, base_config)
        parser = argparse.ArgumentParser()
        args = parser.parse_args(args='')

@ -209,6 +222,10 @@ class TestDistInit(DistributedTest):
    world_size = 1

    def test(self, base_config):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        hidden_dim = 10

        model = SimpleModel(hidden_dim)

@ -227,6 +244,12 @@ class TestInitNoOptimizer(DistributedTest):
    world_size = 1

    def test(self, base_config):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        if get_accelerator().device_name() == "cpu":
            pytest.skip("This test timeout with CPU accelerator")
        del base_config["optimizer"]
        hidden_dim = 10

@ -246,6 +269,10 @@ class TestArgs(DistributedTest):
    world_size = 1

    def test_none_args(self, base_config):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        model = SimpleModel(hidden_dim=10)
        model, _, _, _ = deepspeed.initialize(args=None, model=model, config=base_config)
        data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, device=model.device)

@ -253,6 +280,10 @@ class TestArgs(DistributedTest):
            loss = model(batch[0], batch[1])

    def test_no_args(self, base_config):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        model = SimpleModel(hidden_dim=10)
        model, _, _, _ = deepspeed.initialize(model=model, config=base_config)
        data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, device=model.device)

@ -264,6 +295,10 @@ class TestNoModel(DistributedTest):
    world_size = 1

    def test(self, base_config):
        if get_accelerator().is_fp16_supported():
            base_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            base_config["bf16"] = {"enabled": True}
        model = SimpleModel(hidden_dim=10)
        with pytest.raises(AssertionError):
            model, _, _, _ = deepspeed.initialize(model=None, config=base_config)
@ -18,6 +18,7 @@ from deepspeed.ops.adam import FusedAdam
from deepspeed.runtime.lr_schedules import WARMUP_LR, WarmupLR
from deepspeed.runtime.config import ADAM_OPTIMIZER
from deepspeed.runtime.utils import see_memory_usage, required_torch_version
from deepspeed.accelerator import get_accelerator


@pytest.mark.parametrize('zero_stage', [0, 3])

@ -30,9 +31,6 @@ class TestNoOptim(DistributedTest):

        ds_config = {
            'train_batch_size': self.world_size,
            'fp16': {
                'enabled': True
            },
            'zero_optimization': {
                "stage": zero_stage,
                "offload_param": {

@ -40,6 +38,10 @@ class TestNoOptim(DistributedTest):
                }
            }
        }
        if get_accelerator().is_fp16_supported():
            ds_config["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            ds_config["bf16"] = {"enabled": True}
        # 20B test
        #hidden_dim = 16 * 1024
        hidden_dim = 4

@ -49,11 +51,7 @@ class TestNoOptim(DistributedTest):
        see_memory_usage('pre-init', force=True)
        model, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
        see_memory_usage('post-init', force=True)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.half)
        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
        for batch in data_loader:
            model(batch[0], batch[1])
        see_memory_usage('post-fwds', force=True)

@ -120,6 +118,9 @@ class TestOptimizerImplementation(DistributedTest):
    reuse_dist_env = True

    def test(self, optimizer_extension, model_dtype, grad_accum_dtype):
        if not get_accelerator().is_fp16_supported():
            if model_dtype == 'fp16' or grad_accum_dtype == 'fp16':
                pytest.skip("fp16 is not supported")
        if optimizer_extension == 'zero1':
            zero_stage = 1
        elif optimizer_extension == 'zero2':
@ -5,8 +5,9 @@

import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from pytest import approx
from unit.common import DistributedTest
from unit.common import DistributedTest, preferred_dtype
from unit.multi_output_model import MultiOutputModel, multi_output_dataloader


@ -28,10 +29,11 @@ class TestTwoOutputModel(DistributedTest):
                    "lr": 0.00015
                }
            },
            "fp16": {
                "enabled": True
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        hidden_dim = 10
        weight_value = 0.1

@ -53,7 +55,7 @@ class TestTwoOutputModel(DistributedTest):
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)

            expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device)
            expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

@ -84,10 +86,11 @@ class TestThreeOutputModel(DistributedTest):
                    "lr": 0.00015
                }
            },
            "fp16": {
                "enabled": True
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        hidden_dim = 10
        weight_value = 0.1

@ -111,7 +114,7 @@ class TestThreeOutputModel(DistributedTest):
            loss_tuple = model(inputs, targets)
            assert len(loss_tuple) == 3

            expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device)
            expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device)

            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
@ -10,6 +10,7 @@ import pytest
from unit.common import DistributedTest
from unit.simple_model import SimpleModel, random_dataloader
from mup.shape import set_base_shapes
from deepspeed.accelerator import get_accelerator


@pytest.mark.parametrize("optimizer, expected_opt_class", [("MuAdam", torch.optim.Adam),

@ -31,14 +32,15 @@ class TestMuPOptimizers(DistributedTest):
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 2,
                "cpu_offload": zero_offload
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        set_base_shapes(model, None)
@ -10,6 +10,7 @@ from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop

from unit.common import DistributedTest
from unit.simple_model import SimpleModel, PLD_SimpleModel, random_dataloader
from deepspeed.accelerator import get_accelerator


@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])

@ -39,15 +40,16 @@ class TestPLDModel(DistributedTest):
                    "lr": 0.0001
                }
            },
            "fp16": {
                "enabled": True
            },
            "progressive_layer_drop": {
                "enabled": True,
                "theta": theta,
                "gamma": gamma
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10

        model = PLD_SimpleModel(hidden_dim, empty_grad=False)

@ -80,15 +82,16 @@ class TestNonPLDModel(DistributedTest):
                    "lr": 0.0001
                }
            },
            "fp16": {
                "enabled": True
            },
            "progressive_layer_drop": {
                "enabled": True,
                "theta": theta,
                "gamma": gamma
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
@ -9,6 +9,7 @@ from unit.simple_model import UnusedParametersModel, random_dataloader
from deepspeed.ops.op_builder import CPUAdamBuilder

import deepspeed
from deepspeed.accelerator import get_accelerator


@pytest.mark.parametrize('ignore_unused_parameters', [False, True])

@ -36,11 +37,11 @@ class TestStage2IgnoreUnusedParameters(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        else:
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 4

        model = UnusedParametersModel(hidden_dim=hidden_dim)
@ -16,7 +16,7 @@ from torch.nn.modules.loss import L1Loss
from torch.nn.parameter import Parameter
from torch.nn.utils import skip_init

from unit.common import DistributedTest
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel, random_dataloader

import deepspeed

@ -71,11 +71,11 @@ class TestZeroUnbalancedGradients(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 4

        model = SimpleModel(hidden_dim=hidden_dim)

@ -91,6 +91,8 @@ class TestZero3RepeatForwardLoop(DistributedTest):
    world_size = 1

    def test(self, mics_enabled, zero_stage=3):
        if mics_enabled and get_accelerator().device_name() == "cpu":
            pytest.skip("CPU accelerator does not support this test yet")
        # force all params to be partitioned by forcing threshold=0
        mics_shard_size = -1
        if mics_enabled:

@ -111,11 +113,11 @@ class TestZero3RepeatForwardLoop(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 4

        class AlbertLikeModel(torch.nn.Module):

@ -166,11 +168,11 @@ class TestZeroToFP32(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        class MyModel(torch.nn.Module):

@ -260,11 +262,11 @@ class TestZeroToFP32(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        class MyModel(torch.nn.Module):

@ -366,11 +368,11 @@ class TestIncorectAllgatherBucketSize(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 4

        model = SimpleModel(hidden_dim=hidden_dim)

@ -401,11 +403,11 @@ class TestPartitionNcclAlignment(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 4

        model = SimpleModel(hidden_dim=hidden_dim)

@ -625,6 +627,8 @@ class TestZero3ParamPartitioningBase(DistributedTest):

    @pytest.mark.parametrize("fp16_enabled", [True, False])
    def test_fp16_enabled(self, fp16_enabled):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        self._test(fp16_enabled=fp16_enabled)

    @pytest.mark.parametrize("contiguous_gradients", [True, False])

@ -690,11 +694,11 @@ class TestZero3ParamPartitioningBase(DistributedTest):
                    "lr": 1.0
                }
            },
            "fp16": {
                "enabled": fp16_enabled,
                "loss_scale": 1.0,
            },
        }
        if get_accelerator().is_fp16_supported():
            cfg["fp16"] = {"enabled": True, "loss_scale": 1.0}
        elif get_accelerator().is_bf16_supported():
            cfg["bf16"] = {"enabled": True}

        if offload_optimizer:
            cfg["zero_optimization"]["offload_optimizer"] = {

@ -859,11 +863,11 @@ class TestZero3ParamPartitioningLargeParam(DistributedTest):
                    "lr": 1.0
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 1.0,
            },
        }
        if get_accelerator().is_fp16_supported():
            ds_config["fp16"] = {"enabled": True, "loss_scale": 1.0}
        elif get_accelerator().is_bf16_supported():
            ds_config["bf16"] = {"enabled": True}
        with deepspeed.zero.Init(mem_efficient_linear=False, enabled=init_context_manager):
            model = LargeParamModel()
        ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config)

@ -938,24 +942,24 @@ class TestZero3ParamPartitioningManyParams(DistributedTest):
                    "lr": 1.0
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 1.0,
            },
        }
        if get_accelerator().is_fp16_supported():
            ds_cfg["fp16"] = {"enabled": True, "loss_scale": 1.0}
        elif get_accelerator().is_bf16_supported():
            ds_cfg["bf16"] = {"enabled": True}

        with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=init_context_manager):
            model = ManyParamModel()

        ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg)

        dtype = preferred_dtype()
        for _ in range(3):  # test multiple iterations to cover prefetching
            activations: List[Tensor] = ds_engine(
                torch.ones((param_sz, ), dtype=torch.float16, device=ds_engine.device))
            activations: List[Tensor] = ds_engine(torch.ones((param_sz, ), dtype=dtype, device=ds_engine.device))
            assert len(activations) == n_layers

            partition_sz = math.ceil(param_sz / self.world_size)
            expected_activations = torch.empty(param_sz, dtype=torch.float16, device=ds_engine.device)
            expected_activations = torch.empty(param_sz, dtype=dtype, device=ds_engine.device)
            for start_idx in range(0, param_sz, partition_sz):
                expected_activations[start_idx:start_idx + partition_sz] = dist.get_rank()

@ -1007,11 +1011,11 @@ class TestZero3InitForParentWeightInitialization(DistributedTest):
                    "lr": 1.0
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 1.0,
            },
        }
        if get_accelerator().is_fp16_supported():
            ds_cfg["fp16"] = {"enabled": True, "loss_scale": 1.0}
        elif get_accelerator().is_bf16_supported():
            ds_cfg["bf16"] = {"enabled": True}

        with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=True):
            model = ModelWhereParentInitializesChildWeights()

@ -1207,13 +1211,14 @@ class TestParamPartitioningSkipInit(DistributedTest):
                    "lr": 1e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 3
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10

        class SubModel(torch.nn.Module):

@ -1284,9 +1289,6 @@ class TestZeroOffloadStage1(DistributedTest):
                    "lr": 1e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 1,
                "offload_optimizer": {

@ -1294,6 +1296,10 @@ class TestZeroOffloadStage1(DistributedTest):
                }
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10

        model = SimpleModel(hidden_dim)

@ -1311,6 +1317,8 @@ class TestZero3DictFwd(DistributedTest):
    world_size = 1

    def test(self, return_type):
        if get_accelerator().device_name() == "cpu":
            pytest.skip("CPU accelerator does not support this test yet")
        config_dict = {
            "train_batch_size": 4,
            "steps_per_print": 1,

@ -1320,13 +1328,14 @@ class TestZero3DictFwd(DistributedTest):
                    "lr": 1e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 3
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10

        class MyModel(torch.nn.Module):

@ -1391,11 +1400,11 @@ class TestZeroAdamOptimizerStepCount(DistributedTest):
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 4

        model = SimpleModel(hidden_dim=hidden_dim, nlayers=12)

@ -1445,13 +1454,14 @@ class TestZeroFrozenWeights(DistributedTest):
                    "lr": 1e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": zero_stage
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10

        class MyModel(torch.nn.Module):

@ -1497,9 +1507,6 @@ class TestZeroOffloadOptim(DistributedTest):
            "train_batch_size": 4,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 1,
                "offload_optimizer": {

@ -1508,6 +1515,10 @@ class TestZeroOffloadOptim(DistributedTest):
            },
            "zero_force_ds_cpu_optimizer": force_ds_optim,
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        hidden_dim = 10

        model = SimpleModel(hidden_dim)

@ -1529,15 +1540,15 @@ class TestZeroPartitionCache(DistributedTest):
        hidden_dim = 10
        config_dict = {
            "train_batch_size": 2,
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "zero_optimization": {
                "stage": 3,
                "stage3_param_persistence_threshold": hidden_dim,
            },
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}
        if training:
            config_dict["optimizer"] = {"type": "Adam"}

@ -1546,13 +1557,11 @@ class TestZeroPartitionCache(DistributedTest):

        model, _, _, _ = deepspeed.initialize(model=model, config=config_dict)

        dtype = torch.half
        data_loader = random_dataloader(
            model=model,
            total_samples=6,
            hidden_dim=hidden_dim,
            device=model.device,
            dtype=dtype,
        )

        for _, batch in enumerate(data_loader):

@ -1576,6 +1585,8 @@ class TestEmptyParameterGroup(DistributedTest):
    world_size = 1

    def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_group):
        if dtype == torch.float16 and not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        model = SimpleModel(hidden_dim=4, nlayers=4)
        param_groups = [
            {
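Several hunks above replace hard-coded torch.half tensors with preferred_dtype(), a helper imported from unit.common whose body is not part of this diff. A plausible sketch of such a helper, assuming it simply mirrors the accelerator capability checks used throughout these tests (this is a reconstruction, not the suite's actual implementation):

import torch
from deepspeed.accelerator import get_accelerator


def preferred_dtype():
    # Assumed behavior: prefer fp16, fall back to bf16, then fp32,
    # based on what the active accelerator reports it can run.
    if get_accelerator().is_fp16_supported():
        return torch.float16
    elif get_accelerator().is_bf16_supported():
        return torch.bfloat16
    else:
        return torch.float32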
@ -6,11 +6,13 @@
from types import SimpleNamespace

import torch
import pytest
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel
from utils import setup_serial_env

@ -47,16 +49,17 @@ config = {
            "lr": 0.00015
        }
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 138.
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
    }
}

if get_accelerator().is_fp16_supported():
    config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
    config["bf16"] = {"enabled": True}


class TestZeroGatheredParametersFree(DistributedTest):
    world_size = 1

@ -124,6 +127,8 @@ class TestSerialContext(DistributedTest):
        assert dist.is_initialized()

    def test_scatter_halftype(self):
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported")
        setup_serial_env()

        with deepspeed.zero.Init():

@ -248,7 +253,7 @@ class TestSerialContext(DistributedTest):
        with deepspeed.zero.GatheredParameters(net.linear1.weight):
            assert net.linear1.weight.numel() == net.dim**2

        input = torch.rand(net.dim).to(engine.device).half()
        input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
        loss = engine(input)
        engine.backward(loss)
        engine.step()
@ -8,9 +8,10 @@ import torch
import pytest
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

from utils import setup_serial_env
from unit.common import DistributedTest
from unit.common import DistributedTest, preferred_dtype


class DanglingBias(torch.nn.Linear):

@ -119,16 +120,17 @@ config = {
            "lr": 0.00015
        }
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 138.
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_param_persistence_threshold": 1,
    }
}

if get_accelerator().is_fp16_supported():
    config["fp16"] = {"enabled": True, "loss_scale": 138.}
elif get_accelerator().is_bf16_supported():
    config["bf16"] = {"enabled": True}


class TestReturnParam(DistributedTest):
    world_size = 1

@ -142,7 +144,7 @@ class TestReturnParam(DistributedTest):
        engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

        for _ in range(5):
            input = torch.rand(net.dim).to(engine.device).half()
            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
            loss = engine(input)
            engine.backward(loss)
            engine.step()

@ -158,7 +160,7 @@ class TestReturnParam(DistributedTest):
        engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

        for _ in range(5):
            input = torch.rand(net.dim).to(engine.device).half()
            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
            loss = engine(input)
            assert len(net._external_params) == 1
            assert len(net.dangler._external_params) == 0

@ -176,7 +178,7 @@ class TestReturnParam(DistributedTest):
        engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

        for _ in range(1):
            input = torch.rand(net.dim).to(engine.device).half()
            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
            loss = engine(input)
            if loss is not None:
                if isinstance(loss, dict):
@ -6,11 +6,12 @@
import deepspeed.comm as dist
import torch

from unit.common import DistributedTest
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import random_dataloader

import deepspeed
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module
from deepspeed.accelerator import get_accelerator


class ChooseModuleByCounter(torch.nn.Module):

@ -89,9 +90,6 @@ class TestSetZ3LeafModule(DistributedTest):
                    "lr": 1e-6
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 3,
                "stage3_prefetch_bucket_size": hidden_dim**2,

@ -99,6 +97,10 @@ class TestSetZ3LeafModule(DistributedTest):
                "stage3_max_reuse_distance": 0,
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        model = cls(hidden_dim)

@ -106,7 +108,7 @@ class TestSetZ3LeafModule(DistributedTest):
        set_z3_leaf_modules(model, [cls])
        assert z3_leaf_module(model)

        run_model(model, config_dict, hidden_dim, torch.float16, requires_grad)
        run_model(model, config_dict, hidden_dim, preferred_dtype(), requires_grad)

    def test_choose_module_by_counter(self):
        self._test_set_z3_leaf_modules(ChooseModuleByCounter, True)
@ -7,7 +7,7 @@ import pytest
import deepspeed.comm as dist
import torch

from unit.common import DistributedTest
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import random_dataloader, SimpleModel
from unit.util import bf16_required_version_check

@ -18,6 +18,7 @@ from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_grad, safe
from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.aio import AsyncIOBuilder
from deepspeed.accelerator import get_accelerator

WEIGHT_KEY = 'weight'
FIRST_ORDER_KEY = 'exp_avg'

@ -112,14 +113,14 @@ class TestTensorFragmentGet(DistributedTest):
                    "lr": 1e-6
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 2
            },
            "zero_optimization": {
                "stage": zero_stage,
            }
        }
        if get_accelerator().is_fp16_supported():
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 2}
        elif get_accelerator().is_bf16_supported():
            config_dict["bf16"] = {"enabled": True}

        if offload_device == OffloadDeviceEnum.cpu:
            config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}

@ -139,9 +140,12 @@ class TestTensorFragmentGet(DistributedTest):
        validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False)
        validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True)

        run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_after_bwd, validate_after_step)
        run_fragmented_model(model, config_dict, hidden_dim, preferred_dtype(), validate_after_bwd,
                             validate_after_step)

    def test_bf16_fragments(self, frozen_weights):
        if get_accelerator().device_name() == "cpu":
            pytest.skip("CPU accelerator does not support this test yet.")
        if frozen_weights:
            pytest.skip("TODO: Frozen weights not currently supported by BF16 Optimizer")

@ -302,6 +306,8 @@ class TestTensorFragmentUpdate(DistributedTest):
        }

        if dtype == torch.float16:
            if not get_accelerator().is_fp16_supported():
                pytest.skip("fp16 is not supported")
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
@ -14,6 +14,7 @@ from deepspeed.moe.layer import MoE
from deepspeed.accelerator import get_accelerator

import deepspeed.comm as dist
from .common import preferred_dtype


class SimpleModel(torch.nn.Module):

@ -262,21 +263,21 @@ class PLD_SimpleModel(SimpleModel):
        return hidden_dim


def random_dataset(total_samples, hidden_dim, device, dtype=torch.half):
def random_dataset(total_samples, hidden_dim, device, dtype=preferred_dtype()):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
def random_dataloader(model, total_samples, hidden_dim, device, dtype=preferred_dtype()):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader


def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=torch.half):
def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=preferred_dtype()):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, seq_len, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
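With preferred_dtype() as the default for these data helpers, callers no longer need to pass dtype=torch.half explicitly; batches follow whatever precision the accelerator supports. A hedged usage sketch under those assumptions (single process, illustrative config values; not taken from the diff itself):

import deepspeed
from deepspeed.accelerator import get_accelerator
from unit.simple_model import SimpleModel, random_dataloader

config = {"train_batch_size": 1, "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}
if get_accelerator().is_fp16_supported():
    config["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
    config["bf16"] = {"enabled": True}

model = SimpleModel(hidden_dim=10)
model, _, _, _ = deepspeed.initialize(model=model, config=config, model_parameters=model.parameters())

# dtype defaults to preferred_dtype(), so the batches match the precision chosen above.
data_loader = random_dataloader(model=model, total_samples=4, hidden_dim=10, device=model.device)
for batch in data_loader:
    loss = model(batch[0], batch[1])
    model.backward(loss)
    model.step()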