DeepSpeed JIT op + PyPI support (#496)

Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Authored by Jeff Rasley on 2020-11-12 11:51:38 -08:00; committed by GitHub
parent 0ad4fd880b
commit 31f46feee2
59 changed files with 1667 additions and 675 deletions

5
.gitignore vendored

@ -10,6 +10,7 @@ build/
dist/
*.so
deepspeed.egg-info/
build.txt
# Website
docs/_site/
@ -23,3 +24,7 @@ docs/code-docs/build
# Testing data
tests/unit/saved_checkpoint/
# Dev/IDE data
.vscode
.theia

3
.gitmodules vendored

@ -1,6 +1,3 @@
[submodule "third_party/apex"]
path = third_party/apex
url = https://github.com/NVIDIA/apex.git
[submodule "DeepSpeedExamples"]
path = DeepSpeedExamples
url = https://github.com/microsoft/DeepSpeedExamples

1
MANIFEST.in Normal file

@ -0,0 +1 @@
global-include *.cpp *.h *.cu *.tr *.cuh *.cc *.txt


@ -1,4 +1,5 @@
[![Build Status](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_apis/build/status/microsoft.DeepSpeed?branchName=master)](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master)
[![PyPI version](https://badge.fury.io/py/deepspeed.svg)](https://badge.fury.io/py/deepspeed)
[![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)
[![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
[![Docker Pulls](https://img.shields.io/docker/pulls/deepspeed/deepspeed)](https://hub.docker.com/r/deepspeed/deepspeed)
@ -31,29 +32,25 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
# News
* [2020/09/10] [DeepSpeed: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
* [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
* [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html)
* [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)
* [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html)
* [2020/08/07] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) is now available on-demand
* [2020/07/24] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) on August 6th, 2020
[![DeepSpeed webinar](docs/assets/images/webinar-aug2020.png)](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-Live.html)
* [2020/05/19] [ZeRO-2 & DeepSpeed: Shattering Barriers of Deep Learning Speed & Scale](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/)
* [2020/05/19] [An Order-of-Magnitude Larger and Faster Training with ZeRO-2](https://www.deepspeed.ai/news/2020/05/18/zero-stage2.html)
* [2020/05/19] [The Fastest and Most Efficient BERT Training through Optimized Transformer Kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
# Table of Contents
| Section | Description |
| --------------------------------------- | ------------------------------------------- |
| [Why DeepSpeed?](#why-deepspeed) | DeepSpeed overview |
| [Features](#features) | DeepSpeed features |
| [Further Reading](#further-reading) | DeepSpeed documentation, tutorials, etc. |
| [Contributing](#contributing) | Instructions for contributing to DeepSpeed |
| [Publications](#publications) | DeepSpeed publications |
| [Install](#installation) | Installation details |
| [Features](#features) | Feature list and overview |
| [Further Reading](#further-reading) | Documentation, tutorials, etc. |
| [Contributing](#contributing) | Instructions for contributing |
| [Publications](#publications) | Publications related to DeepSpeed |
# Why DeepSpeed?
Training advanced deep learning models is challenging. Beyond model design,
@ -65,8 +62,32 @@ a large model easily runs out of memory with pure data parallelism and it is
difficult to use model parallelism. DeepSpeed addresses these challenges to
accelerate model development *and* training.
# Features
# Installation
The quickest way to get started with DeepSpeed is via pip; this will install
the latest release of DeepSpeed, which is not tied to specific PyTorch or CUDA
versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer
to as our 'ops'. By default, all of these extensions/ops will be built
just-in-time (JIT) using [torch's JIT C++ extension loader that relies on
ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and
dynamically link them at runtime.
```bash
pip install deepspeed
```
After installation you can validate your install and see which extensions/ops
your machine is compatible with via the DeepSpeed environment report.
```bash
ds_report
```
If you would like to pre-install any of the DeepSpeed extensions/ops (instead
of JIT compiling) or install pre-compiled ops via PyPI please see our [advanced
installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/).
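Under the hood each op is exposed through a builder class. As a minimal sketch
(assuming a machine where the op's dependencies, e.g. ninja and a C++ toolchain,
are available), an individual op can also be compiled and loaded on demand from
Python using the op builder interface added in this release:

```python
# Minimal sketch: JIT-compile and load one DeepSpeed op on demand.
# CPUAdamBuilder is one of the builder classes under op_builder/ in this commit.
from deepspeed.ops.op_builder import CPUAdamBuilder

builder = CPUAdamBuilder()
if builder.is_compatible():           # checks whether the system can build this op
    cpu_adam_module = builder.load()  # triggers torch's JIT C++ extension build on first call
```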
# Features
Below we provide a brief feature list; see our detailed [feature
overview](https://www.deepspeed.ai/features/) for descriptions and usage.


@ -43,7 +43,6 @@ jobs:
conda install -q --yes conda
conda install -q --yes pip
conda install -q --yes gxx_linux-64
if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
echo "PATH=$PATH, LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'
@ -51,9 +50,8 @@ jobs:
- script: |
source activate $(conda_env)
pip install --progress-bar=off torch==$(pytorch.version) torchvision==$(torchvision.version)
#-f https://download.pytorch.org/whl/torch_stable.html
./install.sh --local_only
#python -I basic_install_test.py
pip install .[dev]
ds_report
displayName: 'Install DeepSpeed'
- script: |
@ -71,7 +69,8 @@ jobs:
- script: |
source activate $(conda_env)
pytest --durations=0 --forked --verbose -x tests/unit/
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
displayName: 'Unit tests'
# - script: |


@ -1,65 +0,0 @@
import torch
import warnings
import importlib
import warnings
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
END = '\033[0m'
SUCCESS = f"{GREEN} [SUCCESS] {END}"
WARNING = f"{YELLOW} [WARNING] {END}"
FAIL = f'{RED} [FAIL] {END}'
INFO = ' [INFO]'
try:
import deepspeed
print(f"{SUCCESS} deepspeed successfully imported.")
except ImportError as err:
raise err
print(f"{INFO} torch install path: {torch.__path__}")
print(f"{INFO} torch version: {torch.__version__}, torch.cuda: {torch.version.cuda}")
print(f"{INFO} deepspeed install path: {deepspeed.__path__}")
print(
f"{INFO} deepspeed info: {deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
)
try:
apex_C = importlib.import_module('apex_C')
print(f"{SUCCESS} apex extensions successfully installed")
except Exception as err:
print(f'{WARNING} apex extensions are not installed')
try:
from apex.optimizers import FP16_Optimizer
print(f"{INFO} using old-style apex")
except ImportError:
print(f"{INFO} using new-style apex")
try:
importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
print(f'{SUCCESS} fused lamb successfully installed.')
except Exception as err:
print(f"{WARNING} fused lamb is NOT installed.")
try:
importlib.import_module('deepspeed.ops.transformer.transformer_cuda')
print(f'{SUCCESS} transformer kernels successfully installed.')
except Exception as err:
print(f'{WARNING} transformer kernels are NOT installed.')
try:
with warnings.catch_warnings():
warnings.simplefilter('ignore')
importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils')
import triton
print(f'{SUCCESS} sparse attention successfully installed.')
except ImportError:
print(f'{WARNING} sparse attention is NOT installed.')
try:
importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
print(f'{SUCCESS} cpu-adam (used by ZeRO-offload) successfully installed.')
except ImportError:
print(f'{WARNING} cpu-adam (used by ZeRO-offload) is NOT installed.')

6
bin/ds_report Normal file

@ -0,0 +1,6 @@
#!/usr/bin/env python
from deepspeed.env_report import main
if __name__ == '__main__':
main()

14
csrc/adam/compat.h Normal file

@ -0,0 +1,14 @@
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif

19
csrc/adam/custom_cuda_kernel.cu Normal file → Executable file

@ -4,30 +4,15 @@
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
const float4* input_cast = reinterpret_cast<const float4*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < size) {
float4 data = input_cast[id];
float2 cast_data;
__half* output_h = reinterpret_cast<__half*>(&cast_data);
output_h[0] = (__half)data.x;
output_h[1] = (__half)data.y;
output_h[2] = (__half)data.z;
output_h[3] = (__half)data.w;
output_cast[id] = cast_data;
}
if (id < size) { output[id] = (__half)input[id]; }
}
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 512;
int threads = 1024;
size /= 4;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);


@ -0,0 +1,20 @@
#include <torch/extension.h>
void multi_tensor_adam_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("multi_tensor_adam",
&multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
}


@ -0,0 +1,163 @@
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode (AdamW)
} adamMode_t;
using MATH_T = float;
template <typename T>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float lr,
adamMode_t mode,
const float decay)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_adam_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay)
{
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
AT_CUDA_CHECK(cudaGetLastError());
}


@ -0,0 +1,127 @@
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"
#include <assert.h>
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
void* addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
volatile int* noop_flag,
T tl,
U callable,
ArgTypes... args)
{
// Hand the chunk information to the user-supplied functor to process however it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
int chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if (chunk == chunks_this_tensor - 1) {
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
// std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
} else {
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
// std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}


@ -0,0 +1,25 @@
/*
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/extension.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
at::Tensor flatten(std::vector<at::Tensor> tensors)
{
return torch::utils::flatten_dense_tensors(tensors);
}
std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
return torch::utils::unflatten_dense_tensors(flat, tensors);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("flatten", &flatten, "Flatten dense tensors");
m.def("unflatten", &unflatten, "Unflatten dense tensors");
}


@ -17,24 +17,19 @@ from .utils import log_dist
from .pipe import PipelineModule
try:
from .git_version_info import version, git_hash, git_branch
except ImportError:
version = "0.0.0+unknown"
git_hash = None
git_branch = None
from .git_version_info import version, git_hash, git_branch
def _parse_version(version_str):
'''Parse a version string and extract the major, minor, and patch versions.'''
import re
matched = re.search('^(\d+)\.(\d+)\.(\d+)', version_str)
return int(matched.group(1)), int(matched.group(2)), int(matched.group(3))
# Export version information
version, __version_tag__ = version.split('+')
__version_major__ = int(version.split('.')[0])
__version_minor__ = int(version.split('.')[1])
__version_patch__ = int(version.split('.')[2])
__version__ = '.'.join(
map(str,
[__version_major__,
__version_minor__,
__version_patch__]))
__version__ = f"{__version__}+{__version_tag__}"
__version__ = version
__version_major__, __version_minor__, __version_patch__ = _parse_version(__version__)
__git_hash__ = git_hash
__git_branch__ = git_branch

107
deepspeed/env_report.py Normal file

@ -0,0 +1,107 @@
import torch
import deepspeed
import subprocess
from .ops.op_builder import ALL_OPS
from .git_version_info import installed_ops, torch_info
from .ops import __compatible_ops__ as compatible_ops
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
END = '\033[0m'
SUCCESS = f"{GREEN} [SUCCESS] {END}"
OKAY = f"{GREEN}[OKAY]{END}"
WARNING = f"{YELLOW}[WARNING]{END}"
FAIL = f'{RED}[FAIL]{END}'
INFO = '[INFO]'
color_len = len(GREEN) + len(END)
okay = f"{GREEN}[OKAY]{END}"
warning = f"{YELLOW}[WARNING]{END}"
def op_report():
max_dots = 23
max_dots2 = 11
h = ["op name", "installed", "compatible"]
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print("DeepSpeed C++/CUDA extension op report")
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print("NOTE: Ops not installed will be just-in-time (JIT) compiled at\n"
" runtime if needed. Op compatibility means that your system\n"
" meet the required dependencies to JIT install the op.")
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print("JIT compiled ops requires ninja")
ninja_status = OKAY if ninja_installed() else FAIL
print('ninja', "." * (max_dots - 5), ninja_status)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
print(h[0], "." * (max_dots - len(h[0])), h[1], "." * (max_dots2 - len(h[1])), h[2])
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
installed = f"{GREEN}[YES]{END}"
no = f"{YELLOW}[NO]{END}"
for op_name, builder in ALL_OPS.items():
dots = "." * (max_dots - len(op_name))
is_compatible = OKAY if builder.is_compatible() else no
is_installed = installed if installed_ops[op_name] else no
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) -
(len(is_installed) - color_len))
print(op_name, dots, is_installed, dots2, is_compatible)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
def ninja_installed():
try:
import ninja
except ImportError:
return False
return True
def nvcc_version():
import torch.utils.cpp_extension
cuda_home = torch.utils.cpp_extension.CUDA_HOME
try:
output = subprocess.check_output([cuda_home + "/bin/nvcc",
"-V"],
universal_newlines=True)
except FileNotFoundError:
return f"{RED} [FAIL] nvcc missing {END}"
output_split = output.split()
release_idx = output_split.index("release")
release = output_split[release_idx + 1].replace(',', '').split(".")
return ".".join(release)
def debug_report():
max_dots = 33
report = [
("torch install path",
torch.__path__),
("torch version",
torch.__version__),
("torch cuda version",
torch.version.cuda),
("nvcc version",
nvcc_version()),
("deepspeed install path",
deepspeed.__path__),
("deepspeed info",
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, cuda {torch_info['cuda_version']}"),
]
print("DeepSpeed general environment info:")
for name, value in report:
print(name, "." * (max_dots - len(name)), value)
def main():
op_report()
debug_report()
if __name__ == "__main__":
main()


@ -3,12 +3,11 @@ try:
from .git_version_info_installed import *
except ModuleNotFoundError:
# Will be missing from checkouts that haven't been installed (e.g., readthedocs)
version = '0.3.0+[none]'
version = open('version.txt', 'r').read().strip()
git_hash = '[none]'
git_branch = '[none]'
installed_ops = {
'lamb': False,
'transformer': False,
'sparse-attn': False,
'cpu-adam': False
}
from .ops.op_builder import ALL_OPS
installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
torch_info = {'version': "0.0", "cuda_version": "0.0"}


@ -1,7 +1,6 @@
from ..git_version_info import installed_ops as __installed_ops__
from . import adam
from . import lamb
from . import sparse_attention
from . import transformer
if __installed_ops__['sparse-attn']:
from . import sparse_attention
if __installed_ops__['cpu-adam']:
from . import adam
from ..git_version_info import compatible_ops as __compatible_ops__


@ -1 +1,2 @@
from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam


@ -4,9 +4,9 @@ Copyright 2020 The Microsoft DeepSpeed Team
import math
import torch
import importlib
ds_opt_adam = None
import time
from pathlib import Path
from ..op_builder import CPUAdamBuilder
class DeepSpeedCPUAdam(torch.optim.Optimizer):
@ -67,15 +67,15 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
self.opt_id = DeepSpeedCPUAdam.optimizer_id
DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
global ds_opt_adam
ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
ds_opt_adam.create_adam(self.opt_id,
lr,
betas[0],
betas[1],
eps,
weight_decay,
adamw_mode)
self.ds_opt_adam = CPUAdamBuilder().load()
self.ds_opt_adam.create_adam(self.opt_id,
lr,
betas[0],
betas[1],
eps,
weight_decay,
adamw_mode)
def __setstate__(self, state):
super(DeepSpeedCPUAdam, self).__setstate__(state)
@ -101,18 +101,20 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
print(f'group {group_id} param {param_id} = {p.numel()}')
state['step'] = 0
# gradient momentums
state['exp_avg'] = torch.zeros_like(
p.data,
memory_format=torch.preserve_format)
state['exp_avg'] = torch.zeros_like(p.data,
dtype=p.dtype,
device='cpu')
#memory_format=torch.preserve_format)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(
p.data,
memory_format=torch.preserve_format)
state['exp_avg_sq'] = torch.zeros_like(p.data,
dtype=p.dtype,
device='cpu')
#memory_format=torch.preserve_format)
state['step'] += 1
if fp16_param_groups is not None:
ds_opt_adam.adam_update_copy(
self.ds_opt_adam.adam_update_copy(
self.opt_id,
state['step'],
group['lr'],
@ -122,11 +124,11 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
state['exp_avg_sq'],
fp16_param_groups[group_id][param_id].data)
else:
ds_opt_adam.adam_update(self.opt_id,
state['step'],
group['lr'],
p.data,
p.grad.data,
state['exp_avg'],
state['exp_avg_sq'])
self.ds_opt_adam.adam_update(self.opt_id,
state['step'],
group['lr'],
p.data,
p.grad.data,
state['exp_avg'],
state['exp_avg_sq'])
return loss


@ -0,0 +1,182 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
'''
import torch
import importlib
from .multi_tensor_apply import MultiTensorApply
multi_tensor_applier = MultiTensorApply(2048 * 32)
from ..op_builder import FusedAdamBuilder
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Currently GPU-only.
This version of fused Adam implements 2 fusions.
* Fusion of the Adam update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
Adam was proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay (also known as AdamW) (default: True)
set_grad_none (bool, optional): whether to set grad to None when the zero_grad()
method is called. (default: True)
.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
adam_w_mode=True,
weight_decay=0.,
amsgrad=False,
set_grad_none=True):
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay)
super(FusedAdam, self).__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
fused_adam_cuda = FusedAdamBuilder().load()
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_adam = fused_adam_cuda.multi_tensor_adam
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(FusedAdam, self).zero_grad()
def step(self,
closure=None,
grads=None,
output_params=None,
scale=None,
grad_norms=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
"""
if any(p is not None for p in [grads, output_params, scale, grad_norms]):
raise RuntimeError(
'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
)
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
# assume same step across group now to simplify things
# per-parameter step could easily be supported by making it a tensor, or by passing a list into the kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
for p in group['params']:
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError(
'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
)
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if p.dtype == torch.float16:
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state['exp_avg'])
v_16.append(state['exp_avg_sq'])
elif p.dtype == torch.float32:
g_32.append(p.grad.data)
p_32.append(p.data)
m_32.append(state['exp_avg'])
v_32.append(state['exp_avg_sq'])
else:
raise RuntimeError('FusedAdam only supports fp16 and fp32.')
if (len(g_16) > 0):
multi_tensor_applier(self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_16,
p_16,
m_16,
v_16],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
self.adam_w_mode,
bias_correction,
group['weight_decay'])
if (len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_32,
p_32,
m_32,
v_32],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
self.adam_w_mode,
bias_correction,
group['weight_decay'])
return loss
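A hedged usage sketch for the FusedAdam optimizer above (assumes a CUDA machine with a build toolchain, since the fused op is JIT-compiled on first use): construct it like `torch.optim.Adam` and call `step()` with no arguments.

```python
import torch
from deepspeed.ops.adam import FusedAdam

# Sketch only: FusedAdam is meant as a drop-in replacement for torch.optim.Adam/AdamW.
model = torch.nn.Linear(8, 8).cuda()    # fp32 params; fp16 params are also supported
optimizer = FusedAdam(model.parameters(), lr=1e-3, adam_w_mode=True)

loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()
optimizer.step()       # deprecated apex-style arguments (grads, scale, ...) raise RuntimeError
optimizer.zero_grad()  # sets grads to None by default (set_grad_none=True)
```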


@ -0,0 +1,15 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex, commit a109f85
'''
import torch
class MultiTensorApply(object):
def __init__(self, chunk_size):
self.chunk_size = chunk_size
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)

1
deepspeed/ops/csrc Symbolic link

@ -0,0 +1 @@
../../csrc


@ -1 +1 @@
from deepspeed.ops.lamb.fused_lamb import FusedLamb
from .fused_lamb import FusedLamb


@ -5,8 +5,8 @@ Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
'''
import types
import importlib
import torch
from ..op_builder import FusedLambBuilder
class FusedLamb(torch.optim.Optimizer):
@ -48,15 +48,7 @@ class FusedLamb(torch.optim.Optimizer):
max_coeff=10.0,
min_coeff=0.01,
amsgrad=False):
global fused_lamb_cuda
try:
fused_lamb_cuda = importlib.import_module(
"deepspeed.ops.lamb.fused_lamb_cuda")
except ImportError as err:
print(
"Unable to import Lamb cuda extension, please build DeepSpeed with cuda/cpp extensions."
)
raise err
self.fused_lamb_cuda = FusedLambBuilder().load()
if amsgrad:
raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
@ -173,22 +165,22 @@ class FusedLamb(torch.optim.Optimizer):
out_p = torch.tensor(
[],
dtype=torch.float) if output_param is None else output_param
lamb_coeff = fused_lamb_cuda.lamb(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
max_coeff,
min_coeff,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
lamb_coeff = self.fused_lamb_cuda.lamb(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
max_coeff,
min_coeff,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
self.lamb_coeffs.append(lamb_coeff)
return loss

1
deepspeed/ops/op_builder Symbolic link

@ -0,0 +1 @@
../../op_builder


@ -2,13 +2,12 @@
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
import importlib
import warnings
try:
import triton
except ImportError:
warnings.warn("Unable to import triton, sparse attention will not be accessible")
import torch
import math
from deepspeed.ops.sparse_attention.trsrc import matmul
from .trsrc import matmul
from ..op_builder import SparseAttnBuilder
triton = None
##############
@ -27,6 +26,9 @@ class _sparse_matmul(torch.autograd.Function):
# between `seg_size` elements
@staticmethod
def load_balance(sizes, block):
global triton
if triton is None:
triton = importlib.import_module('triton')
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
@ -83,11 +85,18 @@ class _sparse_matmul(torch.autograd.Function):
##########################
# SPARSE = DENSE x DENSE #
##########################
cpp_utils = importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils')
sdd_segment = cpp_utils.sdd_segment
cpp_utils = None
sdd_segment = None
@staticmethod
def _load_utils():
if _sparse_matmul.cpp_utils is None:
_sparse_matmul.cpp_utils = SparseAttnBuilder().load()
_sparse_matmul.sdd_segment = _sparse_matmul.cpp_utils.sdd_segment
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
_sparse_matmul._load_utils()
start_width = 64 // block
segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
@ -118,6 +127,10 @@ class _sparse_matmul(torch.autograd.Function):
packs,
bench,
time):
global triton
if triton is None:
triton = importlib.import_module('triton')
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
@ -332,6 +345,10 @@ class _sparse_matmul(torch.autograd.Function):
packs,
bench,
time):
global triton
if triton is None:
triton = importlib.import_module('triton')
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
@ -413,6 +430,10 @@ class _sparse_matmul(torch.autograd.Function):
packs,
bench,
time):
global triton
if triton is None:
triton = importlib.import_module('triton')
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]


@ -2,17 +2,17 @@
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
import warnings
try:
import triton
except ImportError:
warnings.warn("Unable to import triton, sparse attention will not be accessible")
import importlib
import torch
import math
from deepspeed.ops.sparse_attention.trsrc import softmax_fwd, softmax_bwd
from .trsrc import softmax_fwd, softmax_bwd
fwd_kernels = dict()
bwd_kernels = dict()
# Delay importing triton unless we need it
triton = None
class _sparse_softmax(torch.autograd.Function):
@ -52,6 +52,10 @@ class _sparse_softmax(torch.autograd.Function):
apply_attn_mask,
kp_mask_mode,
attn_mask_mode):
global triton
if triton is None:
triton = importlib.import_module('triton')
if max_k >= 32768:
raise NotImplementedError('Reductions larger than 32768 elements '\
'are not yet implemented')
@ -112,6 +116,10 @@ class _sparse_softmax(torch.autograd.Function):
maxlut,
bench,
time):
global triton
if triton is None:
triton = importlib.import_module('triton')
apply_scale = False if scale == 1.0 else True
# handle None rpe
@ -180,6 +188,10 @@ class _sparse_softmax(torch.autograd.Function):
@staticmethod
def backward(ctx, dx):
global triton
if triton is None:
triton = importlib.import_module('triton')
# retrieve from context
x, lut = ctx.saved_tensors
# run kernel


@ -1 +1 @@
from deepspeed.ops.transformer.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


@ -8,6 +8,8 @@ import torch
from torch import nn
from torch.autograd import Function
from ..op_builder import TransformerBuilder, StochasticTransformerBuilder
# Cuda modules will be imported if needed
transformer_cuda_module = None
stochastic_transformer_cuda_module = None
@ -483,19 +485,12 @@ class DeepSpeedTransformerLayer(nn.Module):
self.norm_w = initial_weights[7]
self.norm_b = initial_biases[7]
# Import cuda modules if needed
# Load cuda modules if needed
global transformer_cuda_module, stochastic_transformer_cuda_module
if transformer_cuda_module is None or stochastic_transformer_cuda_module is None:
try:
transformer_cuda_module = importlib.import_module(
"deepspeed.ops.transformer.transformer_cuda")
stochastic_transformer_cuda_module = importlib.import_module(
"deepspeed.ops.transformer.stochastic_transformer_cuda")
except ImportError as err:
print(
"Unable to import transformer cuda extension, please build DeepSpeed with cuda/cpp extensions."
)
raise err
if transformer_cuda_module is None and not self.config.stochastic_mode:
transformer_cuda_module = TransformerBuilder().load()
if stochastic_transformer_cuda_module is None and self.config.stochastic_mode:
stochastic_transformer_cuda_module = StochasticTransformerBuilder().load()
# create the layer in cuda kernels.
cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module


@ -7,8 +7,6 @@ import torch
import warnings
import torch.distributed as dist
import apex
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter
@ -36,22 +34,17 @@ from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from .utils import ensure_directory_exists
from ..ops.op_builder import UtilsBuilder
from ..ops.adam import DeepSpeedCPUAdam
from ..ops.adam import FusedAdam
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
try:
from apex_C import flatten
from apex_C import unflatten
from apex import amp
except ImportError:
try:
_ = warned_flatten
except NameError:
logger.warning(
"Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
)
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
# Fail silently so we don't spam logs unnecessarily if user isn't using amp
pass
def split_half_float_double_csr(tensors):
@ -201,6 +194,11 @@ class DeepSpeedEngine(Module):
if self.dump_state():
print_configuration(self, 'DeepSpeedEngine')
# Load pre-installed or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
from mpi4py import MPI
@ -558,6 +556,12 @@ class DeepSpeedEngine(Module):
amp_params = self.amp_params()
if self.global_rank == 0:
logger.info(f"Initializing AMP with these params: {amp_params}")
try:
logger.info("Initializing Apex amp from: {}".format(amp.__path__))
except NameError:
# If apex/amp is available it will be imported above
raise RuntimeError(
"Unable to import apex/amp, please make sure it is installed")
self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
self._broadcast_model()
elif self.fp16_enabled():
@ -584,17 +588,18 @@ class DeepSpeedEngine(Module):
# T|F T F torch.optim.Adam
# T F T|F DeepSpeedCPUAdam(adam_w_mode)
# F F T|F FusedAdam(adam_w_mode)
if torch_adam and adam_w_mode:
optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
elif torch_adam and not adam_w_mode:
optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
elif self.zero_cpu_offload() and not torch_adam:
from deepspeed.ops.adam import DeepSpeedCPUAdam
if torch_adam:
if adam_w_mode:
optimizer = torch.optim.AdamW(model_parameters,
**optimizer_parameters)
else:
optimizer = torch.optim.Adam(model_parameters,
**optimizer_parameters)
elif self.zero_cpu_offload():
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=adam_w_mode)
elif not self.zero_cpu_offload() and not torch_adam:
from apex.optimizers.fused_adam import FusedAdam
else:
optimizer_parameters[ADAM_W_MODE_PARAM] = adam_w_mode
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
@ -614,8 +619,7 @@ class DeepSpeedEngine(Module):
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if isinstance(optimizer,
apex.optimizers.FusedAdam) or self.optimizer_name(
) == ONEBIT_ADAM_OPTIMIZER:
FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale')
timers = self.timers if self.wall_clock_breakdown() else None
@ -1072,7 +1076,7 @@ class DeepSpeedEngine(Module):
ranks=[0])
def allreduce_bucket(self, bucket):
tensor = flatten(bucket)
tensor = self.flatten(bucket)
tensor_to_allreduce = tensor
@ -1100,7 +1104,7 @@ class DeepSpeedEngine(Module):
def allreduce_and_copy(self, small_bucket):
allreduced = self.allreduce_bucket(small_bucket)
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):


@ -15,26 +15,15 @@ import collections
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
from ...ops.op_builder import UtilsBuilder
#Toggle this to true to enable correctness test
#with gradient partitioning and without
pg_correctness_test = False
try:
from apex_C import flatten
from apex_C import unflatten
except ImportError:
try:
_ = warned_flatten
except NameError:
logger.warning(
"apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
)
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
def input(msg):
return
@ -132,6 +121,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1):
# Load pre-installed or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {allgather_bucket_size}")
@ -1053,7 +1047,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
rank = None
tensor = flatten(bucket)
tensor = self.flatten(bucket)
tensor_to_allreduce = tensor
@ -1095,7 +1089,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
with torch.cuda.stream(stream):
allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
def allreduce_no_retain(self,


@ -1,8 +1,8 @@
import torch
import torch.distributed as dist
import apex
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import FusedAdam
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
@ -23,11 +23,14 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
return my_group
ZERO_SUPPORTED_OPTIMIZERS = [
torch.optim.Adam,
apex.optimizers.FusedAdam,
DeepSpeedCPUAdam
]
ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, FusedAdam, DeepSpeedCPUAdam]
# Add apex FusedAdam to supported list if apex is installed
try:
import apex
ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
except ImportError:
pass
def is_zero_supported_optimizer(optimizer):


@ -30,6 +30,7 @@ collections:
output: true
permalink: /:collection/:path/
order:
- advanced-install.md
- getting-started.md
- azure.md
- cifar-10.md


@ -0,0 +1,86 @@
---
title: "Installation Details"
date: 2020-10-28
---
The quickest way to get started with DeepSpeed is via pip; this will install
the latest release of DeepSpeed, which is not tied to specific PyTorch or CUDA
versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer
to as our 'ops'. By default, all of these extensions/ops will be built
just-in-time (JIT) using [torch's JIT C++ extension loader that relies on
ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and
dynamically link them at runtime.
```bash
pip install deepspeed
```
After installation you can validate your install and see which ops your machine
is compatible with via the DeepSpeed environment report, generated with `ds_report` or
`python -m deepspeed.env_report`. We've found this report useful when debugging
DeepSpeed install or compatibility issues.
```bash
ds_report
```
## Install DeepSpeed from source
After cloning the DeepSpeed repo from GitHub, you can install DeepSpeed in
JIT mode via pip (see below). This install should complete
quickly since it is not compiling any C++/CUDA source files.
```bash
pip install .
```
For installs spanning multiple nodes we find it useful to install DeepSpeed
using the
[install.sh](https://github.com/microsoft/DeepSpeed/blob/master/install.sh)
script in the repo. This will build a Python wheel locally and copy it to all
the nodes listed in your hostfile (either given via `--hostfile`, or defaulting to
`/job/hostfile`).
## Pre-install DeepSpeed Ops
Sometimes we have found it useful to pre-install either some or all DeepSpeed
C++/CUDA ops instead of using the JIT compiled path. In order to support
pre-installation we introduce build environment flags to turn on/off building
specific ops.
You can indicate to our installer (either install.sh or pip install) that you
want to attempt to install all of our ops by setting the `DS_BUILD_OPS`
environment variable to 1, for example:
```bash
DS_BUILD_OPS=1 pip install .
```
We will only install ops that are compatible with your machine. For more
details on which ops are compatible with your system, please try our `ds_report`
tool described above.
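For illustration, the same installed/compatible information that `ds_report`
prints can also be queried programmatically through the op builder registry
added in this release (a sketch mirroring `deepspeed/env_report.py`, not a
separately supported API):

```python
# Sketch: enumerate the registered op builders and report their status,
# mirroring what ds_report does internally.
from deepspeed.ops.op_builder import ALL_OPS
from deepspeed.git_version_info import installed_ops

for name, builder in ALL_OPS.items():
    print(f"{name}: installed={installed_ops[name]}, compatible={builder.is_compatible()}")
```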
If you want to install only a specific op (e.g., FusedLamb), you can view the
op-specific build environment variable (set as `BUILD_VAR`) in the corresponding
op builder class in the
[op\_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder)
directory. For example, to install only the Fused Lamb op you would install via:
```bash
DS_BUILD_FUSED_LAMB=1 pip install .
```
## Feature specific dependencies
Some DeepSpeed features require specific dependencies outside of the general
dependencies of DeepSpeed.
* For Python package dependencies per feature/op, please
see our [requirements
directory](https://github.com/microsoft/DeepSpeed/tree/master/requirements).
* We attempt to keep the system level dependencies to a minimum, however some features do require special system-level packages. Please see our `ds_report` tool output to see if you are missing any system-level packages for a given feature.
## Pre-compiled DeepSpeed builds from PyPI
Coming soon


@ -7,9 +7,9 @@ date: 2020-05-15
## Installation
* Installing is as simple as `pip install deepspeed`; [see more details](/tutorials/advanced-install/).
* Please see our [Azure tutorial](/tutorials/azure/) to get started with DeepSpeed on Azure!
* If you're not on Azure, we recommend using our docker image via `docker pull deepspeed/deepspeed:latest` which contains a pre-installed version of DeepSpeed and all the necessary dependencies.
* If you want to install DeepSpeed manually, we provide an install script `install.sh` to help install on a local machine or across an entire cluster.
## Writing DeepSpeed Models
DeepSpeed model training is accomplished using the DeepSpeed engine. The engine


@ -28,8 +28,9 @@ initiative to enable next-generation AI capabilities at scale, where you can fin
information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale).
# What's New?
* [2020/10/28] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
* [DeepSpeed: Extreme-scale model training for everyone]({{ site.press_release_v3 }})
* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone]({{ site.press_release_v3 }})
* [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
* [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html)
* [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)


@ -15,16 +15,13 @@ By default will install deepspeed and all third party dependencies across all ma
hostfile (hostfile: /job/hostfile). If no hostfile exists, will only install locally
[optional]
-d, --deepspeed_only Install only deepspeed and no third party dependencies
-t, --third_party_only Install only third party dependencies and not deepspeed
-l, --local_only Install only on local machine
-s, --pip_sudo Run pip install with sudo (default: no sudo)
-r, --allow_sudo Allow script to be run by root (probably don't want this, instead use --pip_sudo)
-n, --no_clean Do not clean prior build state, by default prior build files are removed before building wheels
-m, --pip_mirror Use the specified pip mirror (default: the default pip mirror)
-H, --hostfile Path to MPI-style hostfile (default: /job/hostfile)
-a, --apex_commit Install a specific commit hash of apex, instead of the one deepspeed points to
-k, --skip_requirements Skip installing DeepSpeed requirements
-v, --verbose Verbose logging
-h, --help This help text
"""
}
@ -42,27 +39,12 @@ apex_commit=""
skip_requirements=0
allow_sudo=0
no_clean=0
verbose=0
while [[ $# -gt 0 ]]
do
key="$1"
case $key in
-d|--deepspeed_only)
deepspeed_install=1;
third_party_install=0;
ds_only=1;
shift
;;
-t|--third_party_only)
deepspeed_install=0;
third_party_install=1;
tp_only=1;
shift
;;
-l|--local_only)
local_only=1;
shift
;;
-s|--pip_sudo)
pip_sudo=1;
shift
@ -72,13 +54,8 @@ case $key in
shift
shift
;;
-a|--apex_commit)
apex_commit=$2;
shift
shift
;;
-k|--skip_requirements)
skip_requirements=1;
-v|--verbose)
verbose=1;
shift
;;
-r|--allow_sudo)
@ -126,12 +103,18 @@ if [ "$ds_only" == "1" ] && [ "$tp_only" == "1" ]; then
exit 1
fi
if [ "$verbose" == "1" ]; then
VERBOSE="-v"
else
VERBOSE=""
fi
rm_if_exist() {
echo "Attempting to remove $1"
if [ -f $1 ]; then
rm -v $1
rm $VERBOSE $1
elif [ -d $1 ]; then
rm -vr $1
rm -r $VERBOSE $1
fi
}
@ -141,10 +124,6 @@ if [ "$no_clean" == "0" ]; then
rm_if_exist dist
rm_if_exist build
rm_if_exist deepspeed.egg-info
# remove apex build files
rm_if_exist third_party/apex/dist
rm_if_exist third_party/apex/build
rm_if_exist third_party/apex/apex.egg-info
fi
if [ "$pip_sudo" == "1" ]; then
@ -154,60 +133,25 @@ else
fi
if [ "$pip_mirror" != "" ]; then
PIP_INSTALL="pip install -v -i $pip_mirror"
PIP_INSTALL="pip install $VERBOSE -i $pip_mirror"
else
PIP_INSTALL="pip install -v"
PIP_INSTALL="pip install $VERBOSE"
fi
if [ ! -f $hostfile ]; then
echo "No hostfile exists at $hostfile, installing locally"
local_only=1
fi
if [ "$skip_requirements" == "0" ]; then
# Ensure dependencies are installed locally
$PIP_SUDO $PIP_INSTALL -r requirements/requirements.txt
fi
# Build wheels
if [ "$third_party_install" == "1" ]; then
echo "Checking out sub-module(s)"
git submodule update --init --recursive
echo "Building apex wheel"
cd third_party/apex
if [ "$apex_commit" != "" ]; then
echo "Installing a non-standard version of apex at commit: $apex_commit"
git fetch
git checkout $apex_commit
fi
python setup.py -v --cpp_ext --cuda_ext bdist_wheel
cd -
echo "Installing apex locally so that deepspeed will build"
$PIP_SUDO pip uninstall -y apex
$PIP_SUDO $PIP_INSTALL third_party/apex/dist/apex*.whl
fi
if [ "$deepspeed_install" == "1" ]; then
echo "Building deepspeed wheel"
python setup.py -v bdist_wheel
fi
echo "Building deepspeed wheel"
python setup.py $VERBOSE bdist_wheel
if [ "$local_only" == "1" ]; then
if [ "$deepspeed_install" == "1" ]; then
echo "Installing deepspeed"
$PIP_SUDO pip uninstall -y deepspeed
$PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl
# -I to exclude local directory files
python -I basic_install_test.py
if [ $? == 0 ]; then
echo "Installation is successful"
else
echo "Installation failed"
fi
fi
echo "Installing deepspeed"
$PIP_SUDO pip uninstall -y deepspeed
$PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl
ds_report
else
local_path=`pwd`
if [ -f $hostfile ]; then
@ -216,28 +160,16 @@ else
echo "hostfile not found, cannot proceed"
exit 1
fi
export PDSH_RCMD_TYPE=ssh;
export PDSH_RCMD_TYPE=ssh
tmp_wheel_path="/tmp/deepspeed_wheels"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; else mkdir -pv $tmp_wheel_path; fi"
pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/
if [ "$skip_requirements" == "0" ]; then
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
fi
if [ "$third_party_install" == "1" ]; then
pdsh -w $hosts "$PIP_SUDO pip uninstall -y apex"
pdcp -w $hosts third_party/apex/dist/apex*.whl $tmp_wheel_path/
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/apex*.whl"
pdsh -w $hosts 'python -c "import apex"'
fi
if [ "$deepspeed_install" == "1" ]; then
echo "Installing deepspeed"
pdsh -w $hosts "$PIP_SUDO pip uninstall -y deepspeed"
pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/
pdcp -w $hosts basic_install_test.py $tmp_wheel_path/
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl"
pdsh -w $hosts "python $tmp_wheel_path/basic_install_test.py"
echo "Installation is successful"
fi
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl $tmp_wheel_path/basic_install_test.py; rmdir $tmp_wheel_path; fi"
echo "Installing deepspeed"
pdsh -w $hosts "$PIP_SUDO pip uninstall -y deepspeed"
pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl"
pdsh -w $hosts "ds_report"
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rmdir $tmp_wheel_path; fi"
fi
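
With ops now JIT-compiled, the install script no longer builds apex or ships a separate basic_install_test.py; it builds a single deepspeed wheel and relies on ds_report (run locally or over pdsh) to verify the result. A rough Python-level equivalent of that check, a sketch that uses the compatibility table the new setup.py records (imports taken from the test changes later in this diff):

# Minimal post-install sanity check; ds_report prints the fuller, formatted version of this.
import deepspeed
from deepspeed.ops.op_builder import CPUAdamBuilder, SparseAttnBuilder

print(deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME])     # True if cpu_adam can build here
print(deepspeed.ops.__compatible_ops__[SparseAttnBuilder.NAME])  # False when llvm/cmake are missing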

20
op_builder/__init__.py Normal file
View File

@ -0,0 +1,20 @@
from .cpu_adam import CPUAdamBuilder
from .fused_adam import FusedAdamBuilder
from .fused_lamb import FusedLambBuilder
from .sparse_attn import SparseAttnBuilder
from .transformer import TransformerBuilder
from .stochastic_transformer import StochasticTransformerBuilder
from .utils import UtilsBuilder
# TODO: infer this list instead of hard coded
# List of all available ops
__op_builders__ = [
CPUAdamBuilder(),
FusedAdamBuilder(),
FusedLambBuilder(),
SparseAttnBuilder(),
TransformerBuilder(),
StochasticTransformerBuilder(),
UtilsBuilder()
]
ALL_OPS = {op.name: op for op in __op_builders__}
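
ALL_OPS maps each builder's NAME to a ready-made instance; it is what setup.py (further down in this diff) iterates over to decide what to pre-build and what ds_report reports on. A minimal sketch of looking one op up from the repository root:

# Sketch: query the op registry the same way setup.py does below.
from op_builder import ALL_OPS

cpu_adam = ALL_OPS['cpu_adam']
print(cpu_adam.BUILD_VAR)   # DS_BUILD_CPU_ADAM
print(cpu_adam.sources())   # ['csrc/adam/cpu_adam.cpp', 'csrc/adam/custom_cuda_kernel.cu']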

245
op_builder/builder.py Normal file
View File

@ -0,0 +1,245 @@
import os
import time
import torch
import importlib
from pathlib import Path
import subprocess
from abc import ABC, abstractmethod
YELLOW = '\033[93m'
END = '\033[0m'
WARNING = f"{YELLOW} [WARNING] {END}"
DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
def assert_no_cuda_mismatch():
import torch.utils.cpp_extension
cuda_home = torch.utils.cpp_extension.CUDA_HOME
assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
# Ensure there is not a cuda version mismatch between torch and nvcc compiler
output = subprocess.check_output([cuda_home + "/bin/nvcc",
"-V"],
universal_newlines=True)
output_split = output.split()
release_idx = output_split.index("release")
release = output_split[release_idx + 1].replace(',', '').split(".")
# Ignore patch versions, only look at major + minor
installed_cuda_version = ".".join(release[:2])
torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
# This is a show-stopping error, should probably not proceed past this
if installed_cuda_version != torch_cuda_version:
raise Exception(
f"Installed CUDA version {installed_cuda_version} does not match the "
f"version torch was compiled with {torch.version.cuda}, unable to compile "
"cuda/cpp extensions without a matching cuda version.")
def assert_torch_info(torch_info):
install_torch_version = torch_info['version']
install_cuda_version = torch_info['cuda_version']
current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
current_torch_version = ".".join(torch.__version__.split('.')[:2])
if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
raise RuntimeError(
"PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
"with a different version than what is being used at runtime. Please re-install "
f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
f"torch={current_torch_version}, cuda={current_cuda_version}")
class OpBuilder(ABC):
def __init__(self, name):
self.name = name
self.jit_mode = False
@abstractmethod
def absolute_name(self):
'''
Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
will be installed as something like: deepspeed/ops/adam/cpu_adam.so
'''
pass
@abstractmethod
def sources(self):
'''
Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
'''
pass
def include_paths(self):
'''
Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
'''
return []
def nvcc_args(self):
'''
Returns optional list of compiler flags to forward to nvcc when building CUDA sources
'''
return []
def cxx_args(self):
'''
Returns optional list of compiler flags to forward to the build
'''
return []
def is_compatible(self):
'''
Check if all non-python dependencies are satisfied to build this op
'''
return True
def python_requirements(self):
'''
Override if op wants to define special dependencies, otherwise will
take self.name and load requirements-<op-name>.txt if it exists.
'''
path = f'requirements/requirements-{self.name}.txt'
requirements = []
if os.path.isfile(path):
with open(path, 'r') as fd:
requirements = [r.strip() for r in fd.readlines()]
return requirements
def command_exists(self, cmd):
if '|' in cmd:
cmds = cmd.split("|")
else:
cmds = [cmd]
valid = False
for cmd in cmds:
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
valid = valid or result.wait() == 0
if not valid and len(cmds) > 1:
print(
f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!"
)
elif not valid and len(cmds) == 1:
print(
f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!"
)
return valid
def warning(self, msg):
print(f"{WARNING} {msg}")
def deepspeed_src_path(self, code_path):
if os.path.isabs(code_path):
return code_path
else:
return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
def builder(self):
from torch.utils.cpp_extension import CppExtension
return CppExtension(name=self.absolute_name(),
sources=self.sources(),
include_dirs=self.include_paths(),
extra_compile_args={'cxx': self.cxx_args()})
def load(self, verbose=True):
from ...git_version_info import installed_ops, torch_info
if installed_ops[self.name]:
# Ensure the op we're about to load was compiled with the same
# torch/cuda versions we are currently using at runtime.
if isinstance(self, CUDAOpBuilder):
assert_torch_info(torch_info)
return importlib.import_module(self.absolute_name())
else:
return self.jit_load(verbose)
def jit_load(self, verbose=True):
if not self.is_compatible():
raise RuntimeError(
f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue."
)
try:
import ninja
except ImportError:
raise RuntimeError(
f"Unable to JIT load the {self.name} op due to ninja not being installed."
)
if isinstance(self, CUDAOpBuilder):
assert_no_cuda_mismatch()
self.jit_mode = True
from torch.utils.cpp_extension import load
# Ensure directory exists to prevent race condition in some cases
ext_path = os.path.join(
os.environ.get('TORCH_EXTENSIONS_DIR',
DEFAULT_TORCH_EXTENSION_PATH),
self.name)
os.makedirs(ext_path, exist_ok=True)
start_build = time.time()
op_module = load(
name=self.name,
sources=[self.deepspeed_src_path(path) for path in self.sources()],
extra_include_paths=[
self.deepspeed_src_path(path) for path in self.include_paths()
],
extra_cflags=self.cxx_args(),
extra_cuda_cflags=self.nvcc_args(),
verbose=verbose)
build_duration = time.time() - start_build
if verbose:
print(f"Time to load {self.name} op: {build_duration} seconds")
return op_module
class CUDAOpBuilder(OpBuilder):
def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
args = []
if self.jit_mode:
# Compile for underlying architecture since we know it at runtime
CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability()
compute_capability = f"{CC_MAJOR}{CC_MINOR}"
args.append('-gencode')
args.append(
f'arch=compute_{compute_capability},code=compute_{compute_capability}')
else:
# Cross-compile mode, compile for various architectures
for compute_capability in cross_compile_archs:
args.append('-gencode')
args.append(
f'arch=compute_{compute_capability},code=compute_{compute_capability}'
)
return args
def version_dependent_macros(self):
# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
version_ge_1_3 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
version_ge_1_3 = ['-DVERSION_GE_1_3']
version_ge_1_5 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5']
return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
def is_compatible(self):
return super().is_compatible()
def builder(self):
from torch.utils.cpp_extension import CUDAExtension
assert_no_cuda_mismatch()
return CUDAExtension(name=self.absolute_name(),
sources=self.sources(),
include_dirs=self.include_paths(),
extra_compile_args={
'cxx': self.cxx_args(),
'nvcc': self.nvcc_args()
})
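
load() is the runtime entry point for every op: if the op was pre-built at install time it imports the installed extension (after assert_torch_info confirms the torch/CUDA versions still match), otherwise jit_load() compiles the listed sources with ninja via torch.utils.cpp_extension.load and caches the result under TORCH_EXTENSIONS_DIR. A minimal sketch of the caller's view, assuming a machine with ninja and a CUDA toolkit matching the installed torch:

# Sketch: import-or-JIT in one call; the first JIT use compiles csrc/adam/*.cu and caches it.
from deepspeed.ops.op_builder import FusedAdamBuilder

fused_adam_module = FusedAdamBuilder().load()
print(type(fused_adam_module))   # the compiled extension module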

77
op_builder/cpu_adam.py Normal file
View File

@ -0,0 +1,77 @@
import os
import torch
import warnings
from .builder import CUDAOpBuilder
class CPUAdamBuilder(CUDAOpBuilder):
BUILD_VAR = "DS_BUILD_CPU_ADAM"
NAME = "cpu_adam"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.adam.{self.NAME}_op'
def sources(self):
return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/custom_cuda_kernel.cu']
def include_paths(self):
CUDA_INCLUDE = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")
return ['csrc/includes', CUDA_INCLUDE]
def available_vector_instructions(self):
try:
import cpufeature
except ImportError:
warnings.warn(
f'import cpufeature failed - CPU vector optimizations are not available for CPUAdam'
)
return {}
cpu_vector_instructions = {}
try:
cpu_vector_instructions = cpufeature.CPUFeature
except Exception:
warnings.warn(
f'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam'
)
return {}
return cpu_vector_instructions
def cxx_args(self):
CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
cpu_info = self.available_vector_instructions()
SIMD_WIDTH = ''
if 'Intel' in cpu_info.get('VendorId', ''):
if cpu_info.get('AVX512f', False):
SIMD_WIDTH = '-D__AVX512__'
elif cpu_info.get('AVX2', False):
SIMD_WIDTH = '-D__AVX256__'
return [
'-O3',
'-std=c++14',
f'-L{CUDA_LIB64}',
'-lcudart',
'-lcublas',
'-g',
'-Wno-reorder',
'-march=native',
'-fopenmp',
SIMD_WIDTH
]
def nvcc_args(self):
args = [
'-O3',
'--use_fast_math',
'-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__'
]
args += self.compute_capability_args()
return args
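
The CPU Adam builder compiles a C++ source plus a CUDA kernel and picks its SIMD flags from cpufeature at build time; AVX512f becomes -D__AVX512__ and AVX2 becomes -D__AVX256__ on Intel CPUs. A small sketch that inspects the detection without triggering a build:

# Sketch: see what cpufeature reports on this machine (empty dict if cpufeature is missing).
from deepspeed.ops.op_builder import CPUAdamBuilder

cpu_info = CPUAdamBuilder().available_vector_instructions()
print(cpu_info.get('VendorId'), cpu_info.get('AVX512f'), cpu_info.get('AVX2'))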

25
op_builder/fused_adam.py Normal file
View File

@ -0,0 +1,25 @@
import torch
from .builder import CUDAOpBuilder
class FusedAdamBuilder(CUDAOpBuilder):
BUILD_VAR = "DS_BUILD_FUSED_ADAM"
NAME = "fused_adam"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.adam.{self.NAME}_op'
def sources(self):
return ['csrc/adam/fused_adam_frontend.cpp', 'csrc/adam/multi_tensor_adam.cu']
def include_paths(self):
return ['csrc/includes']
def cxx_args(self):
return ['-O3'] + self.version_dependent_macros()
def nvcc_args(self):
return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()

25
op_builder/fused_lamb.py Normal file
View File

@ -0,0 +1,25 @@
import torch
from .builder import CUDAOpBuilder
class FusedLambBuilder(CUDAOpBuilder):
BUILD_VAR = 'DS_BUILD_FUSED_LAMB'
NAME = "fused_lamb"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.lamb.{self.NAME}_op'
def sources(self):
return ['csrc/lamb/fused_lamb_cuda.cpp', 'csrc/lamb/fused_lamb_cuda_kernel.cu']
def include_paths(self):
return ['csrc/includes']
def cxx_args(self):
return ['-O3'] + self.version_dependent_macros()
def nvcc_args(self):
return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()
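
These two builders back the fused CUDA optimizers that previously required apex; the test changes later in this diff switch to importing FusedAdam from deepspeed.ops.adam. A sketch of the end-user view (needs a CUDA device; the first use may trigger a JIT build of the kernel):

# Sketch: the fused Adam kernel is consumed through a regular torch-style optimizer.
import torch
from deepspeed.ops.adam import FusedAdam

param = torch.nn.Parameter(torch.randn(64, device='cuda'))
optimizer = FusedAdam([param], lr=1e-3)
param.grad = torch.randn_like(param)
optimizer.step()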

36
op_builder/sparse_attn.py Normal file
View File

@ -0,0 +1,36 @@
import torch
import warnings
from .builder import OpBuilder
class SparseAttnBuilder(OpBuilder):
BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
NAME = "sparse_attn"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.sparse_attention.{self.NAME}_op'
def sources(self):
return ['csrc/sparse_attention/utils.cpp']
def cxx_args(self):
return ['-O2', '-fopenmp']
def is_compatible(self):
# Check to see if llvm and cmake are installed since they are dependencies
required_commands = ['llvm-config|llvm-config-9', 'cmake']
command_status = list(map(self.command_exists, required_commands))
deps_compatible = all(command_status)
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5
if not torch_compatible:
self.warning(
f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}'
)
return super().is_compatible() and deps_compatible and torch_compatible
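
Unlike the pure CUDA builders, sparse attention gates itself on external tools (llvm-config or llvm-config-9, cmake) and on torch >= 1.5, so unsupported systems skip the op instead of failing the whole install. A minimal sketch of that gating in use:

# Sketch: compatibility gating; nothing is built unless the checks above pass.
from deepspeed.ops.op_builder import SparseAttnBuilder

builder = SparseAttnBuilder()
if builder.is_compatible():
    sparse_utils = builder.load()   # JIT-builds csrc/sparse_attention/utils.cpp
else:
    print("sparse_attn is not supported on this system; skipping")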

18
op_builder/stochastic_transformer.py Normal file
View File

@ -0,0 +1,18 @@
import torch
from .transformer import TransformerBuilder
class StochasticTransformerBuilder(TransformerBuilder):
BUILD_VAR = "DS_BUILD_STOCHASTIC_TRANSFORMER"
NAME = "stochastic_transformer"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.transformer.{self.NAME}_op'
def nvcc_args(self):
args = super().nvcc_args()
args.append('-D__STOCHASTIC_MODE__')
return args

44
op_builder/transformer.py Normal file
View File

@ -0,0 +1,44 @@
import torch
from .builder import CUDAOpBuilder
class TransformerBuilder(CUDAOpBuilder):
BUILD_VAR = "DS_BUILD_TRANSFORMER"
NAME = "transformer"
def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)
def absolute_name(self):
return f'deepspeed.ops.transformer.{self.NAME}_op'
def sources(self):
return [
'csrc/transformer/ds_transformer_cuda.cpp',
'csrc/transformer/cublas_wrappers.cu',
'csrc/transformer/transform_kernels.cu',
'csrc/transformer/gelu_kernels.cu',
'csrc/transformer/dropout_kernels.cu',
'csrc/transformer/normalize_kernels.cu',
'csrc/transformer/softmax_kernels.cu',
'csrc/transformer/general_kernels.cu'
]
def include_paths(self):
return ['csrc/includes']
def nvcc_args(self):
args = [
'-O3',
'--use_fast_math',
'-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__'
]
return args + self.compute_capability_args()
def cxx_args(self):
return ['-O3', '-std=c++14', '-g', '-Wno-reorder']
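
The stochastic variant above reuses this builder unchanged and only appends one macro, so both kernels share sources and flags. A sketch comparing the two (no GPU is needed just to inspect flags, since the non-JIT path uses the hard-coded cross-compile list; import path assumed to match the other builders):

# Sketch: the only nvcc difference between the two transformer kernels is the macro.
from deepspeed.ops.op_builder import TransformerBuilder, StochasticTransformerBuilder

base_args = set(TransformerBuilder().nvcc_args())
stochastic_args = set(StochasticTransformerBuilder().nvcc_args())
print(stochastic_args - base_args)   # {'-D__STOCHASTIC_MODE__'}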

15
op_builder/utils.py Normal file
View File

@ -0,0 +1,15 @@
from .builder import OpBuilder
class UtilsBuilder(OpBuilder):
BUILD_VAR = "DS_BUILD_UTILS"
NAME = "utils"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.{self.NAME}_op'
def sources(self):
return ['csrc/utils/flatten_unflatten.cpp']
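
The utils op packages the tensor flatten/unflatten helpers used when the runtime creates contiguous buffers. A sketch of loading it; the helper names on the loaded module are shown for illustration only and are not spelled out in this diff:

# Sketch: load the utils op (pre-built import or JIT compile of csrc/utils/flatten_unflatten.cpp).
from deepspeed.ops.op_builder import UtilsBuilder

util_ops = UtilsBuilder().load()
# flat = util_ops.flatten(tensor_list)             # hypothetical helper name
# tensors = util_ops.unflatten(flat, tensor_list)  # hypothetical helper name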

View File

@ -2,5 +2,6 @@ torch>=1.2
torchvision>=0.4.0
tqdm
psutil
cpufeature
tensorboardX==1.8
ninja
cpufeature

365
setup.py
View File

@ -16,7 +16,7 @@ import warnings
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension
VERSION = "0.3.0"
import op_builder
def fetch_requirements(path):
@ -24,88 +24,33 @@ def fetch_requirements(path):
return [r.strip() for r in fd.readlines()]
def available_vector_instructions():
try:
import cpufeature
except ImportError:
warnings.warn(
f'import cpufeature failed - CPU vector optimizations are not available for CPUAdam'
)
return {}
cpu_vector_instructions = {}
try:
cpu_vector_instructions = cpufeature.CPUFeature
except _:
warnings.warn(
f'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam'
)
return {}
return cpu_vector_instructions
install_requires = fetch_requirements('requirements/requirements.txt')
dev_requires = fetch_requirements('requirements/requirements-dev.txt')
sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
extras_require = {
'1bit_adam': fetch_requirements('requirements/requirements-1bit-adam.txt'),
'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'),
'dev': fetch_requirements('requirements/requirements-dev.txt'),
}
# If MPI is available add 1bit-adam requirements
if torch.cuda.is_available():
if shutil.which('ompi_info') or shutil.which('mpiname'):
onebit_adam_requires = fetch_requirements(
'requirements/requirements-1bit-adam.txt')
onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
install_requires += onebit_adam_requires
cupy = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}"
extras_require['1bit_adam'].append(cupy)
# Constants for each op
LAMB = "lamb"
TRANSFORMER = "transformer"
SPARSE_ATTN = "sparse-attn"
CPU_ADAM = "cpu-adam"
cpu_vector_instructions = available_vector_instructions()
# Build environment variables for custom builds
DS_BUILD_LAMB_MASK = 1
DS_BUILD_TRANSFORMER_MASK = 10
DS_BUILD_SPARSE_ATTN_MASK = 100
DS_BUILD_CPU_ADAM_MASK = 1000
# Allow for build_cuda to turn on or off all ops
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK
DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS
# Set default of each op based on if build_cuda is set
OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM', 0)) * DS_BUILD_CPU_ADAM_MASK
DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK
# Final effective mask is the bitwise OR of each op
BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN
| DS_BUILD_CPU_ADAM)
install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, CPU_ADAM], False)
if BUILD_MASK & DS_BUILD_LAMB:
install_ops[LAMB] = True
if BUILD_MASK & DS_BUILD_CPU_ADAM:
install_ops[CPU_ADAM] = True
if BUILD_MASK & DS_BUILD_TRANSFORMER:
install_ops[TRANSFORMER] = True
if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
install_ops[SPARSE_ATTN] = True
if len(install_ops) == 0:
print("Building without any cuda/cpp extensions")
print(f'BUILD_MASK={BUILD_MASK}, install_ops={install_ops}')
# Make an [all] extra that installs all needed dependencies
all_extras = set()
for extra in extras_require.items():
for req in extra[1]:
all_extras.add(req)
extras_require['all'] = list(all_extras)
cmdclass = {}
# For any pre-installed ops force disable ninja
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
TORCH_MAJOR = torch.__version__.split('.')[0]
TORCH_MINOR = torch.__version__.split('.')[1]
if not torch.cuda.is_available():
# Fix to allow docker builds, similar to https://github.com/NVIDIA/apex/issues/486
@ -116,230 +61,118 @@ if not torch.cuda.is_available():
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
version_ge_1_3 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
version_ge_1_3 = ['-DVERSION_GE_1_3']
version_ge_1_5 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
SIMD_WIDTH = ''
if cpu_vector_instructions.get('AVX512f', False):
SIMD_WIDTH = '-D__AVX512__'
elif cpu_vector_instructions.get('AVX2', False):
SIMD_WIDTH = '-D__AVX256__'
print("SIMD_WIDTH = ", SIMD_WIDTH)
ext_modules = []
## Lamb ##
if BUILD_MASK & DS_BUILD_LAMB:
ext_modules.append(
CUDAExtension(name='deepspeed.ops.lamb.fused_lamb_cuda',
sources=[
'csrc/lamb/fused_lamb_cuda.cpp',
'csrc/lamb/fused_lamb_cuda_kernel.cu'
],
include_dirs=['csrc/includes'],
extra_compile_args={
'cxx': [
'-O3',
] + version_dependent_macros,
'nvcc': ['-O3',
'--use_fast_math'] + version_dependent_macros
}))
from op_builder import ALL_OPS
## Adam ##
if BUILD_MASK & DS_BUILD_CPU_ADAM:
ext_modules.append(
CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op',
sources=[
'csrc/adam/cpu_adam.cpp',
'csrc/adam/custom_cuda_kernel.cu',
],
include_dirs=['csrc/includes',
'/usr/local/cuda/include'],
extra_compile_args={
'cxx': [
'-O3',
'-std=c++14',
'-L/usr/local/cuda/lib64',
'-lcudart',
'-lcublas',
'-g',
'-Wno-reorder',
'-march=native',
'-fopenmp',
SIMD_WIDTH
],
'nvcc': [
'-O3',
'--use_fast_math',
'-gencode',
'arch=compute_61,code=compute_61',
'-gencode',
'arch=compute_70,code=compute_70',
'-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__'
]
}))
## Transformer ##
if BUILD_MASK & DS_BUILD_TRANSFORMER:
ext_modules.append(
CUDAExtension(name='deepspeed.ops.transformer.transformer_cuda',
sources=[
'csrc/transformer/ds_transformer_cuda.cpp',
'csrc/transformer/cublas_wrappers.cu',
'csrc/transformer/transform_kernels.cu',
'csrc/transformer/gelu_kernels.cu',
'csrc/transformer/dropout_kernels.cu',
'csrc/transformer/normalize_kernels.cu',
'csrc/transformer/softmax_kernels.cu',
'csrc/transformer/general_kernels.cu'
],
include_dirs=['csrc/includes'],
extra_compile_args={
'cxx': ['-O3',
'-std=c++14',
'-g',
'-Wno-reorder'],
'nvcc': [
'-O3',
'--use_fast_math',
'-gencode',
'arch=compute_61,code=compute_61',
'-gencode',
'arch=compute_60,code=compute_60',
'-gencode',
'arch=compute_70,code=compute_70',
'-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__'
]
}))
ext_modules.append(
CUDAExtension(name='deepspeed.ops.transformer.stochastic_transformer_cuda',
sources=[
'csrc/transformer/ds_transformer_cuda.cpp',
'csrc/transformer/cublas_wrappers.cu',
'csrc/transformer/transform_kernels.cu',
'csrc/transformer/gelu_kernels.cu',
'csrc/transformer/dropout_kernels.cu',
'csrc/transformer/normalize_kernels.cu',
'csrc/transformer/softmax_kernels.cu',
'csrc/transformer/general_kernels.cu'
],
include_dirs=['csrc/includes'],
extra_compile_args={
'cxx': ['-O3',
'-std=c++14',
'-g',
'-Wno-reorder'],
'nvcc': [
'-O3',
'--use_fast_math',
'-gencode',
'arch=compute_61,code=compute_61',
'-gencode',
'arch=compute_60,code=compute_60',
'-gencode',
'arch=compute_70,code=compute_70',
'-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__',
'-D__STOCHASTIC_MODE__'
]
}))
# Default to pre-install kernels to false so we rely on JIT
BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', 0))
print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")
def command_exists(cmd):
if '|' in cmd:
cmds = cmd.split("|")
else:
cmds = [cmd]
valid = False
for cmd in cmds:
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
valid = valid or result.wait() == 0
return valid
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
return result.wait() == 0
## Sparse transformer ##
if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
# Check to see if llvm and cmake are installed since they are dependencies
required_commands = ['llvm-config|llvm-config-9', 'cmake']
def op_enabled(op_name):
assert hasattr(ALL_OPS[op_name], 'BUILD_VAR'), \
f"{op_name} is missing BUILD_VAR field"
env_var = ALL_OPS[op_name].BUILD_VAR
return int(os.environ.get(env_var, BUILD_OP_DEFAULT))
command_status = list(map(command_exists, required_commands))
if not all(command_status):
zipped_status = list(zip(required_commands, command_status))
warnings.warn(
f'Missing non-python requirements, please install the missing packages: {zipped_status}'
)
warnings.warn(
'Skipping sparse attention installation due to missing required packages')
# remove from installed ops list
install_ops[SPARSE_ATTN] = False
elif TORCH_MAJOR == 1 and TORCH_MINOR >= 5:
ext_modules.append(
CppExtension(name='deepspeed.ops.sparse_attention.cpp_utils',
sources=['csrc/sparse_attention/utils.cpp'],
extra_compile_args={'cxx': ['-O2',
'-fopenmp']}))
# Add sparse attention requirements
install_requires += sparse_attn_requires
else:
warnings.warn('Unable to meet requirements to install sparse attention')
# remove from installed ops list
install_ops[SPARSE_ATTN] = False
# Add development requirements
install_requires += dev_requires
install_ops = dict.fromkeys(ALL_OPS.keys(), False)
for op_name, builder in ALL_OPS.items():
op_compatible = builder.is_compatible()
# If op is compatible update install reqs so it can potentially build/run later
if op_compatible:
reqs = builder.python_requirements()
install_requires += builder.python_requirements()
# If op install enabled, add builder to extensions
if op_enabled(op_name) and op_compatible:
install_ops[op_name] = op_enabled(op_name)
ext_modules.append(builder.builder())
compatible_ops = {op_name: op.is_compatible() for (op_name, op) in ALL_OPS.items()}
print(f'Install Ops={install_ops}')
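Pre-building is now opt-in: DS_BUILD_OPS sets the default for every op and each builder's BUILD_VAR overrides it, replacing the old bit-mask scheme. A small sketch of the precedence op_enabled() implements (shell usage illustrative):

# Sketch: per-op variables such as DS_BUILD_CPU_ADAM win over the DS_BUILD_OPS default, e.g.
#   DS_BUILD_OPS=1 pip install .                       -> pre-build every compatible op
#   DS_BUILD_OPS=0 DS_BUILD_CPU_ADAM=1 pip install .   -> pre-build only cpu_adam, JIT the rest
import os
os.environ['DS_BUILD_OPS'] = '0'
os.environ['DS_BUILD_CPU_ADAM'] = '1'
default = int(os.environ.get('DS_BUILD_OPS', 0))
print(int(os.environ.get('DS_BUILD_CPU_ADAM', default)))    # 1 -> cpu_adam is pre-built
print(int(os.environ.get('DS_BUILD_FUSED_ADAM', default)))  # 0 -> fused_adam stays JIT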
# Write out version/git info
git_hash_cmd = "git rev-parse --short HEAD"
git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
if command_exists('git'):
result = subprocess.check_output(git_hash_cmd, shell=True)
git_hash = result.decode('utf-8').strip()
result = subprocess.check_output(git_branch_cmd, shell=True)
git_branch = result.decode('utf-8').strip()
if command_exists('git') and 'DS_BUILD_STRING' not in os.environ:
try:
result = subprocess.check_output(git_hash_cmd, shell=True)
git_hash = result.decode('utf-8').strip()
result = subprocess.check_output(git_branch_cmd, shell=True)
git_branch = result.decode('utf-8').strip()
except subprocess.CalledProcessError:
git_hash = "unknown"
git_branch = "unknown"
else:
git_hash = "unknown"
git_branch = "unknown"
print(f"version={VERSION}+{git_hash}, git_hash={git_hash}, git_branch={git_branch}")
# Parse the DeepSpeed version string from version.txt
version_str = open('version.txt', 'r').read().strip()
# Build specifiers like .devX can be added at install time. Otherwise, add the git hash.
# example: DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel
#version_str += os.environ.get('DS_BUILD_STRING', f'+{git_hash}')
# Building wheel for distribution, update version file
if 'DS_BUILD_STRING' in os.environ:
# Build string env specified, probably building for distribution
with open('build.txt', 'w') as fd:
fd.write(os.environ.get('DS_BUILD_STRING'))
version_str += os.environ.get('DS_BUILD_STRING')
elif os.path.isfile('build.txt'):
# build.txt exists, probably installing from distribution
with open('build.txt', 'r') as fd:
version_str += fd.read().strip()
else:
# None of the above, probably installing from source
version_str += f'+{git_hash}'
torch_version = ".".join([TORCH_MAJOR, TORCH_MINOR])
cuda_version = ".".join(torch.version.cuda.split('.')[:2])
torch_info = {"version": torch_version, "cuda_version": cuda_version}
print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}")
with open('deepspeed/git_version_info_installed.py', 'w') as fd:
fd.write(f"version='{VERSION}+{git_hash}'\n")
fd.write(f"version='{version_str}'\n")
fd.write(f"git_hash='{git_hash}'\n")
fd.write(f"git_branch='{git_branch}'\n")
fd.write(f"installed_ops={install_ops}\n")
fd.write(f"compatible_ops={compatible_ops}\n")
fd.write(f"torch_info={torch_info}\n")
print(f'install_requires={install_requires}')
print(f'compatible_ops={compatible_ops}')
print(f'ext_modules={ext_modules}')
setup(name='deepspeed',
version=f"{VERSION}+{git_hash}",
version=version_str,
description='DeepSpeed library',
author='DeepSpeed Team',
author_email='deepspeed@microsoft.com',
url='http://deepspeed.ai',
install_requires=install_requires,
extras_require=extras_require,
packages=find_packages(exclude=["docker",
"third_party",
"csrc"]),
package_data={'deepspeed.ops.sparse_attention.trsrc': ['*.tr']},
scripts=['bin/deepspeed',
'bin/deepspeed.pt',
'bin/ds',
'bin/ds_ssh'],
"third_party"]),
include_package_data=True,
scripts=[
'bin/deepspeed',
'bin/deepspeed.pt',
'bin/ds',
'bin/ds_ssh',
'bin/ds_report'
],
classifiers=[
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',

View File

@ -363,10 +363,18 @@ except ImportError:
self.variance_epsilon = eps
def forward(self, x):
pdtype = x.dtype
x = x.float()
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
return self.weight * x.to(pdtype) + self.bias
#def forward(self, x):
# u = x.mean(-1, keepdim=True)
# s = (x - u).pow(2).mean(-1, keepdim=True)
# x = (x - u) / torch.sqrt(s + self.variance_epsilon)
# return self.weight * x + self.bias
class BertEmbeddings(nn.Module):
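
The layer-norm change above keeps parameters and activations in their original dtype while doing the mean/variance math in fp32, avoiding fp16 precision loss in the reduction. A minimal sketch of the same pattern in isolation:

# Sketch: compute the normalization in float32, return in the input's original dtype.
import torch

x = torch.randn(2, 8, dtype=torch.half)
y = torch.nn.functional.layer_norm(x.float(), (8,)).to(x.dtype)
print(y.dtype)   # torch.float16, but the statistics were computed in float32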

View File

@ -12,6 +12,8 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.pipe.topology import *
PipeTopo = PipeDataParallelTopology
from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder
import argparse
import pytest
import json
@ -152,8 +154,8 @@ def checkpoint_correctness_verification(args,
compare_lr_scheduler_states(trained_model, loaded_model)
@pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
reason="lamb is not installed")
@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
reason="lamb is not compatible")
def test_checkpoint_unfused_optimizer(tmpdir):
config_dict = {
"train_batch_size": 2,
@ -264,11 +266,11 @@ def test_checkpoint_fused_optimizer(tmpdir):
'Adam'),
(2,
True,
'deepspeed_adam'),
'Adam'),
])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
config_dict = {
"train_batch_size": 2,
@ -320,14 +322,14 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_opt
"Adam"),
(2,
True,
'deepspeed_adam'),
'Adam'),
])
def test_checkpoint_zero_no_optimizer(tmpdir,
zero_stage,
use_cpu_offload,
adam_optimizer):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
config_dict = {
"train_batch_size": 2,
@ -385,11 +387,11 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
'Adam'),
(2,
True,
'deepspeed_adam'),
'Adam'),
])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
config_dict = {
"train_batch_size": 2,
@ -459,11 +461,11 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optim
'Adam'),
(2,
True,
'deepspeed_adam'),
'Adam'),
])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
config_dict = {
"train_batch_size": 2,

View File

@ -1,16 +1,16 @@
import argparse
import torch
import apex
import time
import numpy as np
import pytest
import copy
import deepspeed
if not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed", allow_module_level=True)
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.op_builder import CPUAdamBuilder
if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
def check_equal(first, second, atol=1e-2, verbose=False):
@ -32,6 +32,7 @@ def check_equal(first, second, atol=1e-2, verbose=False):
(1048576),
]) # yapf: disable
def test_cpu_adam_opt(model_size):
from deepspeed.ops.adam import DeepSpeedCPUAdam
device = 'cpu'
rng_state = torch.get_rng_state()
param = torch.nn.Parameter(torch.randn(model_size, device=device))
@ -42,7 +43,7 @@ def test_cpu_adam_opt(model_size):
param2 = torch.nn.Parameter(param2_data)
optimizer1 = torch.optim.AdamW([param1])
optimizer2 = apex.optimizers.FusedAdam([param2])
optimizer2 = FusedAdam([param2])
optimizer = DeepSpeedCPUAdam([param])
for i in range(10):

View File

@ -16,8 +16,8 @@ import deepspeed
import sys
if not deepspeed.ops.__installed_ops__['transformer']:
pytest.skip("transformer kernels are not installed", allow_module_level=True)
#if not deepspeed.ops.__installed_ops__['transformer']:
# pytest.skip("transformer kernels are not installed", allow_module_level=True)
def check_equal(first, second, atol=1e-2, verbose=False):
@ -254,6 +254,7 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
check_equal(base_grads, ds_grads, atol=atol, verbose=verbose)
#test_backward[3-1024-120-16-24-True-True-0.05]
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[
(3,1024,120,16,24,True,False, 0.05),

View File

@ -16,8 +16,8 @@ import deepspeed
import sys
if not deepspeed.ops.__installed_ops__['transformer']:
pytest.skip("transformer kernels are not installed", allow_module_level=True)
#if not deepspeed.ops.__installed_ops__['transformer']:
# pytest.skip("transformer kernels are not installed", allow_module_level=True)
def check_equal(first, second, atol=1e-2, verbose=False):

View File

@ -8,9 +8,6 @@ import numpy as np
from common import distributed_test
from simple_model import SimpleModel, args_from_dict
lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
reason="lamb is not installed")
def run_model_step(model, gradient_list):
for value in gradient_list:
@ -168,7 +165,6 @@ def test_fused_some_overflow(tmpdir):
_test_fused_some_overflow(args)
@lamb_available
def test_unfused_no_overflow(tmpdir):
config_dict = {
"train_batch_size": 1,
@ -212,7 +208,6 @@ def test_unfused_no_overflow(tmpdir):
_test_unfused_no_overflow(args)
@lamb_available
def test_unfused_all_overflow(tmpdir):
config_dict = {
"train_batch_size": 1,
@ -258,7 +253,6 @@ def test_unfused_all_overflow(tmpdir):
_test_unfused_all_overflow(args)
@lamb_available
def test_unfused_some_overflow(tmpdir):
config_dict = {
"train_batch_size": 1,

View File

@ -1,18 +1,21 @@
import torch
import apex
import deepspeed
import argparse
import pytest
import json
import os
from deepspeed.ops.adam import FusedAdam
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
reason="lamb is not installed")
try:
from apex import amp
_amp_available = True
except ImportError:
_amp_available = False
amp_available = pytest.mark.skipif(not _amp_available, reason="apex/amp is not installed")
@lamb_available
def test_lamb_fp32_grad_clip(tmpdir):
config_dict = {
"train_batch_size": 2,
@ -48,7 +51,6 @@ def test_lamb_fp32_grad_clip(tmpdir):
_test_lamb_fp32_grad_clip(args=args, model=model, hidden_dim=hidden_dim)
@lamb_available
def test_lamb_fp16_basic(tmpdir):
config_dict = {
"train_batch_size": 2,
@ -86,7 +88,6 @@ def test_lamb_fp16_basic(tmpdir):
_test_lamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)
@lamb_available
def test_lamb_fp16_empty_grad(tmpdir):
config_dict = {
"train_batch_size": 2,
@ -234,8 +235,8 @@ def test_adamw_fp16_empty_grad(tmpdir):
True),
])
def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
#if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
# pytest.skip("cpu-adam is not installed")
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
@ -302,8 +303,8 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offlo
True),
])
def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
#if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
# pytest.skip("cpu-adam is not installed")
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
@ -402,8 +403,8 @@ def test_zero_static_scale_deprecated_format(tmpdir):
True),
])
def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
#if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
# pytest.skip("cpu-adam is not installed")
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
@ -442,8 +443,8 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
True),
])
def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed")
#if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
# pytest.skip("cpu-adam is not installed")
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
@ -489,6 +490,7 @@ def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload):
_test_zero_empty_partition(args)
@amp_available
def test_adam_amp_basic(tmpdir):
config_dict = {"train_batch_size": 1, "steps_per_print": 1, "amp": {"enabled": True}}
args = args_from_dict(tmpdir, config_dict)
@ -514,7 +516,7 @@ def test_adam_amp_basic(tmpdir):
_test_adam_amp_basic(args=args, model=model, hidden_dim=hidden_dim)
@lamb_available
@amp_available
def test_lamb_amp_basic(tmpdir):
config_dict = {
"train_batch_size": 2,
@ -552,6 +554,7 @@ def test_lamb_amp_basic(tmpdir):
_test_lamb_amp_basic(args=args, model=model, hidden_dim=hidden_dim)
@amp_available
def test_adam_amp_o2(tmpdir):
config_dict = {
"train_batch_size": 2,
@ -590,6 +593,7 @@ def test_adam_amp_o2(tmpdir):
_test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)
@amp_available
def test_adam_amp_o2_empty_grad(tmpdir):
config_dict = {
"train_batch_size": 2,
@ -630,11 +634,11 @@ def test_adam_amp_o2_empty_grad(tmpdir):
@pytest.mark.parametrize('zero_stage, optimizer_constructor',
[(1,
apex.optimizers.FusedAdam),
FusedAdam),
(2,
torch.optim.Adam),
(2,
apex.optimizers.FusedAdam)])
FusedAdam)])
def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_constructor):
config_dict = {
"train_batch_size": 2,

View File

@ -6,9 +6,11 @@
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import SparseAttnBuilder
if not deepspeed.ops.__installed_ops__['sparse-attn']:
pytest.skip("cpu-adam is not installed", allow_module_level=True)
if not deepspeed.ops.__compatible_ops__[SparseAttnBuilder.NAME]:
pytest.skip("sparse attention op is not compatible on this system",
allow_module_level=True)
def test_sparse_attention_module_availability():
@ -236,7 +238,7 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo
def _skip_on_cuda_compatability():
pytest.skip("Skip these tests for now until we get our docker image fixed.")
#pytest.skip("Skip these tests for now until we get our docker image fixed.")
if torch.cuda.get_device_capability()[0] != 7:
pytest.skip("needs compute capability 7; v100")
cuda_major = int(torch.version.cuda.split('.')[0]) * 10

1
third_party/apex vendored

Submodule third_party/apex deleted from 494f8ab3fc

1
version.txt Normal file
View File

@ -0,0 +1 @@
0.3.1