DeepSpeed JIT op + PyPI support (#496)
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
.gitignore (vendored, 5 changes)
@@ -10,6 +10,7 @@ build/
dist/
*.so
deepspeed.egg-info/
build.txt

# Website
docs/_site/
@@ -23,3 +24,7 @@ docs/code-docs/build

# Testing data
tests/unit/saved_checkpoint/

# Dev/IDE data
.vscode
.theia
.gitmodules (vendored, 3 changes)
@@ -1,6 +1,3 @@
[submodule "third_party/apex"]
    path = third_party/apex
    url = https://github.com/NVIDIA/apex.git
[submodule "DeepSpeedExamples"]
    path = DeepSpeedExamples
    url = https://github.com/microsoft/DeepSpeedExamples
MANIFEST.in (new file, 1 line)
@@ -0,0 +1 @@
global-include *.cpp *.h *.cu *.tr *.cuh *.cc *.txt
README.md (47 changes)
@@ -1,4 +1,5 @@
[](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master)
[](https://badge.fury.io/py/deepspeed)
[](https://deepspeed.readthedocs.io/en/latest/?badge=latest)
[](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
[](https://hub.docker.com/r/deepspeed/deepspeed)
@@ -31,29 +32,25 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)


# News
* [2020/09/10] [DeepSpeed: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
* [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
* [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html)
* [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)
* [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html)
* [2020/08/07] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) is now available on-demand
* [2020/07/24] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) on August 6th, 2020
[](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-Live.html)
* [2020/05/19] [ZeRO-2 & DeepSpeed: Shattering Barriers of Deep Learning Speed & Scale](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/)
* [2020/05/19] [An Order-of-Magnitude Larger and Faster Training with ZeRO-2](https://www.deepspeed.ai/news/2020/05/18/zero-stage2.html)
* [2020/05/19] [The Fastest and Most Efficient BERT Training through Optimized Transformer Kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)


# Table of Contents
| Section | Description |
| --------------------------------------- | ------------------------------------------- |
| [Why DeepSpeed?](#why-deepspeed) | DeepSpeed overview |
| [Features](#features) | DeepSpeed features |
| [Further Reading](#further-reading) | DeepSpeed documentation, tutorials, etc. |
| [Contributing](#contributing) | Instructions for contributing to DeepSpeed |
| [Publications](#publications) | DeepSpeed publications |
| [Install](#installation) | Installation details |
| [Features](#features) | Feature list and overview |
| [Further Reading](#further-reading) | Documentation, tutorials, etc. |
| [Contributing](#contributing) | Instructions for contributing |
| [Publications](#publications) | Publications related to DeepSpeed |

# Why DeepSpeed?
Training advanced deep learning models is challenging. Beyond model design,
@@ -65,8 +62,32 @@ a large model easily runs out of memory with pure data parallelism and it is
difficult to use model parallelism. DeepSpeed addresses these challenges to
accelerate model development *and* training.

# Features
# Installation

The quickest way to get started with DeepSpeed is via pip; this will install
the latest release of DeepSpeed, which is not tied to specific PyTorch or CUDA
versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer
to as our 'ops'. By default, all of these extensions/ops will be built
just-in-time (JIT) using [torch's JIT C++ extension loader that relies on
ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and
dynamically link them at runtime.

```bash
pip install deepspeed
```

After installation you can validate your install and see which extensions/ops
your machine is compatible with via the DeepSpeed environment report.

```bash
ds_report
```

If you would like to pre-install any of the DeepSpeed extensions/ops (instead
of JIT compiling) or install pre-compiled ops via PyPI, please see our [advanced
installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/).

# Features
Below we provide a brief feature list; see our detailed [feature
overview](https://www.deepspeed.ai/features/) for descriptions and usage.
@@ -43,7 +43,6 @@ jobs:
    conda install -q --yes conda
    conda install -q --yes pip
    conda install -q --yes gxx_linux-64
    if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
    echo "PATH=$PATH, LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
  displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'

@@ -51,9 +50,8 @@ jobs:
- script: |
    source activate $(conda_env)
    pip install --progress-bar=off torch==$(pytorch.version) torchvision==$(torchvision.version)
    #-f https://download.pytorch.org/whl/torch_stable.html
    ./install.sh --local_only
    #python -I basic_install_test.py
    pip install .[dev]
    ds_report
  displayName: 'Install DeepSpeed'

- script: |
@@ -71,7 +69,8 @@ jobs:

- script: |
    source activate $(conda_env)
    pytest --durations=0 --forked --verbose -x tests/unit/
    if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
    TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
  displayName: 'Unit tests'

# - script: |
basic_install_test.py (deleted, 65 lines)
@@ -1,65 +0,0 @@
import torch
import warnings
import importlib
import warnings

GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
END = '\033[0m'
SUCCESS = f"{GREEN} [SUCCESS] {END}"
WARNING = f"{YELLOW} [WARNING] {END}"
FAIL = f'{RED} [FAIL] {END}'
INFO = ' [INFO]'

try:
    import deepspeed
    print(f"{SUCCESS} deepspeed successfully imported.")
except ImportError as err:
    raise err

print(f"{INFO} torch install path: {torch.__path__}")
print(f"{INFO} torch version: {torch.__version__}, torch.cuda: {torch.version.cuda}")
print(f"{INFO} deepspeed install path: {deepspeed.__path__}")
print(
    f"{INFO} deepspeed info: {deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
)

try:
    apex_C = importlib.import_module('apex_C')
    print(f"{SUCCESS} apex extensions successfully installed")
except Exception as err:
    print(f'{WARNING} apex extensions are not installed')

try:
    from apex.optimizers import FP16_Optimizer
    print(f"{INFO} using old-style apex")
except ImportError:
    print(f"{INFO} using new-style apex")

try:
    importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
    print(f'{SUCCESS} fused lamb successfully installed.')
except Exception as err:
    print(f"{WARNING} fused lamb is NOT installed.")

try:
    importlib.import_module('deepspeed.ops.transformer.transformer_cuda')
    print(f'{SUCCESS} transformer kernels successfully installed.')
except Exception as err:
    print(f'{WARNING} transformer kernels are NOT installed.')

try:
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils')
    import triton
    print(f'{SUCCESS} sparse attention successfully installed.')
except ImportError:
    print(f'{WARNING} sparse attention is NOT installed.')

try:
    importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
    print(f'{SUCCESS} cpu-adam (used by ZeRO-offload) successfully installed.')
except ImportError:
    print(f'{WARNING} cpu-adam (used by ZeRO-offload) is NOT installed.')
bin/ds_report (new file, 6 lines)
@@ -0,0 +1,6 @@
#!/usr/bin/env python

from deepspeed.env_report import main

if __name__ == '__main__':
    main()
csrc/adam/compat.h (new file, 14 lines)
@@ -0,0 +1,14 @@
/* Copyright 2020 The Microsoft DeepSpeed Team
   Copyright NVIDIA/apex
   This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
csrc/adam/custom_cuda_kernel.cu (normal file → executable file, 19 changes)
@@ -4,30 +4,15 @@

__global__ void param_update_kernel(const float* input, __half* output, int size)
{
    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);

    int id = blockIdx.x * blockDim.x + threadIdx.x;

    if (id < size) {
        float4 data = input_cast[id];
        float2 cast_data;
        __half* output_h = reinterpret_cast<__half*>(&cast_data);

        output_h[0] = (__half)data.x;
        output_h[1] = (__half)data.y;
        output_h[2] = (__half)data.z;
        output_h[3] = (__half)data.w;

        output_cast[id] = cast_data;
    }
    if (id < size) { output[id] = (__half)input[id]; }
}

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
    int threads = 512;
    int threads = 1024;

    size /= 4;
    dim3 grid_dim((size - 1) / threads + 1);
    dim3 block_dim(threads);

csrc/adam/fused_adam_frontend.cpp (new file, 20 lines)
@@ -0,0 +1,20 @@
#include <torch/extension.h>

void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("multi_tensor_adam",
          &multi_tensor_adam_cuda,
          "Compute and apply gradient update to parameters for Adam optimizer");
}
csrc/adam/multi_tensor_adam.cu (new file, 163 lines)
@@ -0,0 +1,163 @@
/* Copyright 2020 The Microsoft DeepSpeed Team
   Copyright NVIDIA/apex
   This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 512
#define ILP 4

typedef enum {
    ADAM_MODE_0 = 0, // L2 regularization mode
    ADAM_MODE_1 = 1  // Decoupled weight decay mode (AdamW)
} adamMode_t;

using MATH_T = float;

template <typename T>
struct AdamFunctor {
    __device__ __forceinline__ void operator()(int chunk_size,
                                               volatile int* noop_gmem,
                                               TensorListMetadata<4>& tl,
                                               const float beta1,
                                               const float beta2,
                                               const float beta1_correction,
                                               const float beta2_correction,
                                               const float epsilon,
                                               const float lr,
                                               adamMode_t mode,
                                               const float decay)
    {
        // I'd like this kernel to propagate infs/nans.
        // if (*noop_gmem == 1)
        //     return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];

        // potentially use to pass in list of scalar
        // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* g = (T*)tl.addresses[0][tensor_loc];
        g += chunk_idx * chunk_size;

        T* p = (T*)tl.addresses[1][tensor_loc];
        p += chunk_idx * chunk_size;

        T* m = (T*)tl.addresses[2][tensor_loc];
        m += chunk_idx * chunk_size;

        T* v = (T*)tl.addresses[3][tensor_loc];
        v += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;

        // see note in multi_tensor_scale_kernel.cu
        for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
            MATH_T r_g[ILP];
            MATH_T r_p[ILP];
            MATH_T r_m[ILP];
            MATH_T r_v[ILP];
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    r_g[ii] = g[i];
                    r_p[ii] = p[i];
                    r_m[ii] = m[i];
                    r_v[ii] = v[i];
                } else {
                    r_g[ii] = MATH_T(0);
                    r_p[ii] = MATH_T(0);
                    r_m[ii] = MATH_T(0);
                    r_v[ii] = MATH_T(0);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                if (mode == ADAM_MODE_0) { // L2
                    r_g[ii] = r_g[ii] + (decay * r_p[ii]);
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = next_m_unbiased / denom;
                    r_p[ii] = r_p[ii] - (lr * update);
                } else { // weight decay
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
                    r_p[ii] = r_p[ii] - (lr * update);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    p[i] = r_p[ii];
                    m[i] = r_m[ii];
                    v[i] = r_v[ii];
                }
            }
        }
    }
};

void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay)
{
    using namespace at;

    // Handle bias correction mode
    float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
    if (bias_correction == 1) {
        bias_correction1 = 1 - std::pow(beta1, step);
        bias_correction2 = 1 - std::pow(beta2, step);
    }

    // Assume single type across p,g,m1,m2 now
    DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
                                   0,
                                   "adam",
                                   multi_tensor_apply<4>(BLOCK_SIZE,
                                                         chunk_size,
                                                         noop_flag,
                                                         tensor_lists,
                                                         AdamFunctor<scalar_t_0>(),
                                                         beta1,
                                                         beta2,
                                                         bias_correction1,
                                                         bias_correction2,
                                                         epsilon,
                                                         lr,
                                                         (adamMode_t)mode,
                                                         weight_decay);)

    AT_CUDA_CHECK(cudaGetLastError());
}
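For readers who want the update rule without the CUDA plumbing, here is a minimal NumPy sketch (an illustration written for this listing, not repository code) of the per-element math `AdamFunctor` performs in its two modes: `ADAM_MODE_0` folds L2 regularization into the gradient, while `ADAM_MODE_1` applies decoupled weight decay (AdamW).

```python
# Reference sketch (not repository code) of the per-element Adam update
# implemented by AdamFunctor above, covering both weight-decay modes.
import numpy as np

def adam_update(p, g, m, v, lr, beta1, beta2, eps, step, decay, adamw_mode):
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    if not adamw_mode:              # ADAM_MODE_0: L2 decay folded into the gradient
        g = g + decay * p
    m = beta1 * m + (1 - beta1) * g             # first moment
    v = beta2 * v + (1 - beta2) * g * g         # second moment
    update = (m / bias_correction1) / (np.sqrt(v / bias_correction2) + eps)
    if adamw_mode:                  # ADAM_MODE_1: decoupled weight decay (AdamW)
        update = update + decay * p
    return p - lr * update, m, v
```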
csrc/adam/multi_tensor_apply.cuh (new file, 127 lines)
@@ -0,0 +1,127 @@
/* Copyright 2020 The Microsoft DeepSpeed Team
   Copyright NVIDIA/apex
   This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template <int n>
struct TensorListMetadata {
    void* addresses[n][depth_to_max_tensors[n - 1]];
    int sizes[depth_to_max_tensors[n - 1]];
    unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
    int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
    int start_tensor_this_launch;
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
                                          volatile int* noop_flag,
                                          T tl,
                                          U callable,
                                          ArgTypes... args)
{
    // Hand the chunk information to the user-supplied functor to process however it likes.
    callable(chunk_size, noop_flag, tl, args...);
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
                        int chunk_size,
                        const at::Tensor& noop_flag,
                        const std::vector<std::vector<at::Tensor>>& tensor_lists,
                        T callable,
                        ArgTypes... args)
{
    TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
    int len0 = tensor_lists[0].size();
    TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
    auto ref_device = tensor_lists[0][0].device();
    TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
    for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
    {
        TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
        for (int t = 0; t < tensor_lists[l].size(); t++) {
            // TODO: Print which tensor fails.
            bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
            contiguous_memory = (contiguous_memory ||
                                 tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
            TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
            TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
                        "A tensor was not on the same device as the first tensor");
            TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
        }
    }

    int ntensors = tensor_lists[0].size();

    TensorListMetadata<depth> tl;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
    auto stream = at::cuda::getCurrentCUDAStream();

    tl.start_tensor_this_launch = 0;
    int loc_block_info = 0;
    int loc_tensor_info = 0;
    for (int t = 0; t < ntensors; t++) {
        tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
        for (int d = 0; d < depth; d++)
            tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
        loc_tensor_info++;

        int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

        for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
            // std::cout << chunks_this_tensor << std::endl;
            tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
            tl.block_to_chunk[loc_block_info] = chunk;
            loc_block_info++;

            bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                                 chunk == chunks_this_tensor - 1);
            bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
            bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
            if (tensors_full || blocks_full || last_chunk) {
                // using accscalar_t = acc_type<scalar_t, true>;
                multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
                    chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);

                AT_CUDA_CHECK(cudaGetLastError());

                // Reset. The control flow possibilities here make my brain hurt.
                loc_block_info = 0;
                if (chunk == chunks_this_tensor - 1) {
                    // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
                    loc_tensor_info = 0;
                    tl.start_tensor_this_launch = t + 1;
                } else {
                    // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
                    tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
                    for (int d = 0; d < depth; d++)
                        tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
                    loc_tensor_info = 1;
                    tl.start_tensor_this_launch = t;
                }
            }
        }
    }
}
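The host-side loop above is the heart of the fusion: it packs many (tensor, chunk) pairs into a single kernel launch and flushes whenever the fixed-size metadata struct fills. A small Python sketch of that scheduling logic (illustrative only; the capacities are the depth-4 constants from the header, one CUDA block per chunk):

```python
# Illustration (not repository code) of how multi_tensor_apply maps blocks
# to (tensor, chunk) pairs and flushes a launch when capacity is reached.
MAX_TENSORS = 36   # depth_to_max_tensors[3] for depth 4 (g, p, m, v)
MAX_BLOCKS = 320   # depth_to_max_blocks[3]

def schedule_launches(tensor_sizes, chunk_size):
    launches, block_map, tensors_in_launch = [], [], 0
    for t, numel in enumerate(tensor_sizes):
        tensors_in_launch += 1
        chunks = (numel + chunk_size - 1) // chunk_size
        for chunk in range(chunks):
            block_map.append((t, chunk))  # one CUDA block per chunk
            tensors_full = tensors_in_launch == MAX_TENSORS and chunk == chunks - 1
            blocks_full = len(block_map) == MAX_BLOCKS
            last_chunk = t == len(tensor_sizes) - 1 and chunk == chunks - 1
            if tensors_full or blocks_full or last_chunk:
                launches.append(block_map)  # one kernel launch, len(block_map) blocks
                block_map = []
                # a partially processed tensor carries over into the next launch
                tensors_in_launch = 0 if chunk == chunks - 1 else 1
    return launches
```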
csrc/utils/flatten_unflatten.cpp (new file, 25 lines)
@@ -0,0 +1,25 @@
/*
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/extension.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h

at::Tensor flatten(std::vector<at::Tensor> tensors)
{
    return torch::utils::flatten_dense_tensors(tensors);
}

std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
    return torch::utils::unflatten_dense_tensors(flat, tensors);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("flatten", &flatten, "Flatten dense tensors");
    m.def("unflatten", &unflatten, "Unflatten dense tensors");
}
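This tiny extension replaces the `apex_C` flatten/unflatten imports removed later in this commit. A sketch of how the engine code below consumes it, via the `UtilsBuilder().load()` call that appears in the diff (the tensors here are placeholders):

```python
# Sketch mirroring the engine code in this commit: load the (un)flatten op
# (pre-installed or JIT-compiled) and round-trip a bucket of tensors.
import torch
from deepspeed.ops.op_builder import UtilsBuilder

util_ops = UtilsBuilder().load()
bucket = [torch.randn(4), torch.randn(8)]            # placeholder tensors
flat = util_ops.flatten(bucket)                      # one contiguous 12-element tensor
views = util_ops.unflatten(flat, bucket)             # views shaped like the originals
```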
@@ -17,24 +17,19 @@ from .utils import log_dist

from .pipe import PipelineModule

try:
    from .git_version_info import version, git_hash, git_branch
except ImportError:
    version = "0.0.0+unknown"
    git_hash = None
    git_branch = None
from .git_version_info import version, git_hash, git_branch


def _parse_version(version_str):
    '''Parse a version string and extract the major, minor, and patch versions.'''
    import re
    matched = re.search('^(\d+)\.(\d+)\.(\d+)', version_str)
    return int(matched.group(1)), int(matched.group(2)), int(matched.group(3))


# Export version information
version, __version_tag__ = version.split('+')
__version_major__ = int(version.split('.')[0])
__version_minor__ = int(version.split('.')[1])
__version_patch__ = int(version.split('.')[2])
__version__ = '.'.join(
    map(str,
        [__version_major__,
         __version_minor__,
         __version_patch__]))
__version__ = f"{__version__}+{__version_tag__}"
__version__ = version
__version_major__, __version_minor__, __version_patch__ = _parse_version(__version__)
__git_hash__ = git_hash
__git_branch__ = git_branch

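As a quick illustration of `_parse_version` (the input string below is made up), the regex consumes only the leading major.minor.patch triple, so any local version tag after `+` is ignored:

```python
import re
# Hypothetical input; mirrors the regex in _parse_version above.
matched = re.search('^(\d+)\.(\d+)\.(\d+)', '0.3.0+local')
print(int(matched.group(1)), int(matched.group(2)), int(matched.group(3)))  # 0 3 0
```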
deepspeed/env_report.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import torch
import deepspeed
import subprocess
from .ops.op_builder import ALL_OPS
from .git_version_info import installed_ops, torch_info
from .ops import __compatible_ops__ as compatible_ops

GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
END = '\033[0m'
SUCCESS = f"{GREEN} [SUCCESS] {END}"
OKAY = f"{GREEN}[OKAY]{END}"
WARNING = f"{YELLOW}[WARNING]{END}"
FAIL = f'{RED}[FAIL]{END}'
INFO = '[INFO]'

color_len = len(GREEN) + len(END)
okay = f"{GREEN}[OKAY]{END}"
warning = f"{YELLOW}[WARNING]{END}"


def op_report():
    max_dots = 23
    max_dots2 = 11
    h = ["op name", "installed", "compatible"]
    print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
    print("DeepSpeed C++/CUDA extension op report")
    print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))

    print("NOTE: Ops not installed will be just-in-time (JIT) compiled at\n"
          "      runtime if needed. Op compatibility means that your system\n"
          "      meets the required dependencies to JIT install the op.")

    print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
    print("JIT compiled ops require ninja")
    ninja_status = OKAY if ninja_installed() else FAIL
    print('ninja', "." * (max_dots - 5), ninja_status)
    print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
    print(h[0], "." * (max_dots - len(h[0])), h[1], "." * (max_dots2 - len(h[1])), h[2])
    print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
    installed = f"{GREEN}[YES]{END}"
    no = f"{YELLOW}[NO]{END}"
    for op_name, builder in ALL_OPS.items():
        dots = "." * (max_dots - len(op_name))
        is_compatible = OKAY if builder.is_compatible() else no
        is_installed = installed if installed_ops[op_name] else no
        dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) -
                       (len(is_installed) - color_len))
        print(op_name, dots, is_installed, dots2, is_compatible)
    print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))


def ninja_installed():
    try:
        import ninja
    except ImportError:
        return False
    return True


def nvcc_version():
    import torch.utils.cpp_extension
    cuda_home = torch.utils.cpp_extension.CUDA_HOME
    try:
        output = subprocess.check_output([cuda_home + "/bin/nvcc",
                                          "-V"],
                                         universal_newlines=True)
    except FileNotFoundError:
        return f"{RED} [FAIL] nvcc missing {END}"
    output_split = output.split()
    release_idx = output_split.index("release")
    release = output_split[release_idx + 1].replace(',', '').split(".")
    return ".".join(release)


def debug_report():
    max_dots = 33
    report = [
        ("torch install path", torch.__path__),
        ("torch version", torch.__version__),
        ("torch cuda version", torch.version.cuda),
        ("nvcc version", nvcc_version()),
        ("deepspeed install path", deepspeed.__path__),
        ("deepspeed info",
         f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"),
        ("deepspeed wheel compiled w.",
         f"torch {torch_info['version']}, cuda {torch_info['cuda_version']}"),
    ]
    print("DeepSpeed general environment info:")
    for name, value in report:
        print(name, "." * (max_dots - len(name)), value)


def main():
    op_report()
    debug_report()


if __name__ == "__main__":
    main()
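The same building blocks that `ds_report` prints can be queried directly; a minimal sketch using the `ALL_OPS` mapping and the builder interface exactly as `op_report()` does above:

```python
# Query op install/compatibility status programmatically, mirroring
# op_report() above.
from deepspeed.ops.op_builder import ALL_OPS
from deepspeed.git_version_info import installed_ops

for op_name, builder in ALL_OPS.items():
    print(f"{op_name}: installed={installed_ops[op_name]} "
          f"compatible={builder.is_compatible()}")
```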
@@ -3,12 +3,11 @@ try:
    from .git_version_info_installed import *
except ModuleNotFoundError:
    # Will be missing from checkouts that haven't been installed (e.g., readthedocs)
    version = '0.3.0+[none]'
    version = open('version.txt', 'r').read().strip()
    git_hash = '[none]'
    git_branch = '[none]'
    installed_ops = {
        'lamb': False,
        'transformer': False,
        'sparse-attn': False,
        'cpu-adam': False
    }

    from .ops.op_builder import ALL_OPS
    installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
    compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
    torch_info = {'version': "0.0", "cuda_version": "0.0"}
@@ -1,7 +1,6 @@
from ..git_version_info import installed_ops as __installed_ops__
from . import adam
from . import lamb
from . import sparse_attention
from . import transformer
if __installed_ops__['sparse-attn']:
    from . import sparse_attention
if __installed_ops__['cpu-adam']:
    from . import adam

from ..git_version_info import compatible_ops as __compatible_ops__
@@ -1 +1,2 @@
from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam
@@ -4,9 +4,9 @@ Copyright 2020 The Microsoft DeepSpeed Team

import math
import torch
import importlib

ds_opt_adam = None
import time
from pathlib import Path
from ..op_builder import CPUAdamBuilder


class DeepSpeedCPUAdam(torch.optim.Optimizer):
@@ -67,15 +67,15 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
        self.opt_id = DeepSpeedCPUAdam.optimizer_id
        DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1

        global ds_opt_adam
        ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
        ds_opt_adam.create_adam(self.opt_id,
                                lr,
                                betas[0],
                                betas[1],
                                eps,
                                weight_decay,
                                adamw_mode)
        self.ds_opt_adam = CPUAdamBuilder().load()

        self.ds_opt_adam.create_adam(self.opt_id,
                                     lr,
                                     betas[0],
                                     betas[1],
                                     eps,
                                     weight_decay,
                                     adamw_mode)

    def __setstate__(self, state):
        super(DeepSpeedCPUAdam, self).__setstate__(state)
@@ -101,18 +101,20 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                        print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0
                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(
                        p.data,
                        memory_format=torch.preserve_format)
                    state['exp_avg'] = torch.zeros_like(p.data,
                                                        dtype=p.dtype,
                                                        device='cpu')
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(
                        p.data,
                        memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p.data,
                                                           dtype=p.dtype,
                                                           device='cpu')
                    #memory_format=torch.preserve_format)

                state['step'] += 1

                if fp16_param_groups is not None:
                    ds_opt_adam.adam_update_copy(
                    self.ds_opt_adam.adam_update_copy(
                        self.opt_id,
                        state['step'],
                        group['lr'],
@@ -122,11 +124,11 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                        state['exp_avg_sq'],
                        fp16_param_groups[group_id][param_id].data)
                else:
                    ds_opt_adam.adam_update(self.opt_id,
                                            state['step'],
                                            group['lr'],
                                            p.data,
                                            p.grad.data,
                                            state['exp_avg'],
                                            state['exp_avg_sq'])
                    self.ds_opt_adam.adam_update(self.opt_id,
                                                 state['step'],
                                                 group['lr'],
                                                 p.data,
                                                 p.grad.data,
                                                 state['exp_avg'],
                                                 state['exp_avg_sq'])
        return loss
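A hedged usage sketch for the class as changed above: the constructor arguments mirror the `create_adam(...)` call in `__init__`, and the engine code later in this commit constructs it the same way (the model below is a placeholder):

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(128, 128)  # placeholder model
optimizer = DeepSpeedCPUAdam(model.parameters(),
                             lr=1e-3,
                             betas=(0.9, 0.999),
                             eps=1e-8,
                             weight_decay=0,
                             adamw_mode=True)  # AdamW-style decoupled decay
```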
deepspeed/ops/adam/fused_adam.py (new file, 182 lines)
@@ -0,0 +1,182 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team

Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
'''

import torch
import importlib
from .multi_tensor_apply import MultiTensorApply
multi_tensor_applier = MultiTensorApply(2048 * 32)
from ..op_builder import FusedAdamBuilder


class FusedAdam(torch.optim.Optimizer):
    """Implements Adam algorithm.

    Currently GPU-only.

    This version of fused Adam implements 2 fusions.

    * Fusion of the Adam update's elementwise operations
    * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

    Adam was proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay;
            True for decoupled weight decay (also known as AdamW) (default: True)
        set_grad_none (bool, optional): whether to set grad to None when zero_grad()
            method is called. (default: True)

    .. _Adam - A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 adam_w_mode=True,
                 weight_decay=0.,
                 amsgrad=False,
                 set_grad_none=True):

        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay)
        super(FusedAdam, self).__init__(params, defaults)
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none

        fused_adam_cuda = FusedAdamBuilder().load()
        # Skip buffer
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
        self.multi_tensor_adam = fused_adam_cuda.multi_tensor_adam

    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedAdam, self).zero_grad()

    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=None,
             grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
        """
        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError(
                'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
            )
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError(
                        'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
                    )

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedAdam only support fp16 and fp32.')

            if (len(g_16) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     self.adam_w_mode,
                                     bias_correction,
                                     group['weight_decay'])
            if (len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     self.adam_w_mode,
                                     bias_correction,
                                     group['weight_decay'])

        return loss
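Per the docstring and the RuntimeError above, FusedAdam is meant to be constructed like `torch.optim.Adam` and stepped with no extra arguments. A sketch (CUDA required; the model is a placeholder):

```python
import torch
from deepspeed.ops.adam import FusedAdam

model = torch.nn.Linear(1024, 1024).cuda().half()  # placeholder model
optimizer = FusedAdam(model.parameters(), lr=1e-3, adam_w_mode=True)

x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
loss = model(x).sum()
loss.backward()
optimizer.step()       # no deprecated arguments
optimizer.zero_grad()  # sets grads to None by default (set_grad_none=True)
```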
deepspeed/ops/adam/multi_tensor_apply.py (new file, 15 lines)
@@ -0,0 +1,15 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team

Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex, commit a109f85
'''
import torch


class MultiTensorApply(object):
    def __init__(self, chunk_size):
        self.chunk_size = chunk_size

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
deepspeed/ops/csrc (new symbolic link)
@@ -0,0 +1 @@
../../csrc
@@ -1 +1 @@
from deepspeed.ops.lamb.fused_lamb import FusedLamb
from .fused_lamb import FusedLamb
@@ -5,8 +5,8 @@ Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
'''
import types
import importlib
import torch
from ..op_builder import FusedLambBuilder


class FusedLamb(torch.optim.Optimizer):
@@ -48,15 +48,7 @@ class FusedLamb(torch.optim.Optimizer):
                 max_coeff=10.0,
                 min_coeff=0.01,
                 amsgrad=False):
        global fused_lamb_cuda
        try:
            fused_lamb_cuda = importlib.import_module(
                "deepspeed.ops.lamb.fused_lamb_cuda")
        except ImportError as err:
            print(
                "Unable to import Lamb cuda extension, please build DeepSpeed with cuda/cpp extensions."
            )
            raise err
        self.fused_lamb_cuda = FusedLambBuilder().load()

        if amsgrad:
            raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
@@ -173,22 +165,22 @@ class FusedLamb(torch.optim.Optimizer):
                out_p = torch.tensor(
                    [],
                    dtype=torch.float) if output_param is None else output_param
                lamb_coeff = fused_lamb_cuda.lamb(p.data,
                                                  out_p,
                                                  exp_avg,
                                                  exp_avg_sq,
                                                  grad,
                                                  group['lr'],
                                                  beta1,
                                                  beta2,
                                                  max_coeff,
                                                  min_coeff,
                                                  group['eps'],
                                                  combined_scale,
                                                  state['step'],
                                                  self.eps_mode,
                                                  bias_correction,
                                                  group['weight_decay'])
                lamb_coeff = self.fused_lamb_cuda.lamb(p.data,
                                                       out_p,
                                                       exp_avg,
                                                       exp_avg_sq,
                                                       grad,
                                                       group['lr'],
                                                       beta1,
                                                       beta2,
                                                       max_coeff,
                                                       min_coeff,
                                                       group['eps'],
                                                       combined_scale,
                                                       state['step'],
                                                       self.eps_mode,
                                                       bias_correction,
                                                       group['weight_decay'])
                self.lamb_coeffs.append(lamb_coeff)
        return loss
deepspeed/ops/op_builder (new symbolic link)
@@ -0,0 +1 @@
../../op_builder
@@ -2,13 +2,12 @@
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
import importlib
import warnings
try:
    import triton
except ImportError:
    warnings.warn("Unable to import triton, sparse attention will not be accessible")
import torch
import math
from deepspeed.ops.sparse_attention.trsrc import matmul
from .trsrc import matmul
from ..op_builder import SparseAttnBuilder

triton = None


##############
@@ -27,6 +26,9 @@ class _sparse_matmul(torch.autograd.Function):
    # between `seg_size` elements
    @staticmethod
    def load_balance(sizes, block):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')
        # segment size
        # heuristics taken from OpenAI blocksparse code
        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
@@ -83,11 +85,18 @@ class _sparse_matmul(torch.autograd.Function):
    ##########################
    # SPARSE = DENSE x DENSE #
    ##########################
    cpp_utils = importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils')
    sdd_segment = cpp_utils.sdd_segment
    cpp_utils = None
    sdd_segment = None

    @staticmethod
    def _load_utils():
        if _sparse_matmul.cpp_utils is None:
            _sparse_matmul.cpp_utils = SparseAttnBuilder().load()
            _sparse_matmul.sdd_segment = _sparse_matmul.cpp_utils.sdd_segment

    @staticmethod
    def make_sdd_lut(layout, block, dtype, device):
        _sparse_matmul._load_utils()
        start_width = 64 // block
        segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
        luts, widths, packs = [], [], []
@@ -118,6 +127,10 @@ class _sparse_matmul(torch.autograd.Function):
                packs,
                bench,
                time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')

        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
@@ -332,6 +345,10 @@ class _sparse_matmul(torch.autograd.Function):
                packs,
                bench,
                time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')

        # shapes / dtypes
        AS0 = a.size(0)
        AS1 = a.size(1)
@@ -413,6 +430,10 @@ class _sparse_matmul(torch.autograd.Function):
                packs,
                bench,
                time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')

        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
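The pattern repeated through these sparse-attention hunks defers the triton import from module load to first use, so `import deepspeed` no longer requires triton to be installed. A standalone sketch of the same idiom (illustrative, not repository code):

```python
# Lazy-import idiom used above: pay the import cost (and require the
# optional dependency) only on the first call that actually needs it.
import importlib

triton = None

def kernel_entry_point(*args):
    global triton
    if triton is None:
        triton = importlib.import_module('triton')  # raises only if actually used
    # ... kernel construction using triton would go here ...
```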
@@ -2,17 +2,17 @@
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py

import warnings
try:
    import triton
except ImportError:
    warnings.warn("Unable to import triton, sparse attention will not be accessible")
import importlib
import torch
import math
from deepspeed.ops.sparse_attention.trsrc import softmax_fwd, softmax_bwd
from .trsrc import softmax_fwd, softmax_bwd

fwd_kernels = dict()
bwd_kernels = dict()

# Delay importing triton unless we need it
triton = None


class _sparse_softmax(torch.autograd.Function):

@@ -52,6 +52,10 @@ class _sparse_softmax(torch.autograd.Function):
                    apply_attn_mask,
                    kp_mask_mode,
                    attn_mask_mode):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')

        if max_k >= 32768:
            raise NotImplementedError('Reductions larger than 32768 elements '\
                                      'are not yet implemented')
@@ -112,6 +116,10 @@ class _sparse_softmax(torch.autograd.Function):
                maxlut,
                bench,
                time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')

        apply_scale = False if scale == 1.0 else True

        # handle None rpe
@@ -180,6 +188,10 @@ class _sparse_softmax(torch.autograd.Function):

    @staticmethod
    def backward(ctx, dx):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')

        # retrieve from context
        x, lut = ctx.saved_tensors
        # run kernel
@@ -1 +1 @@
from deepspeed.ops.transformer.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig

@@ -8,6 +8,8 @@ import torch
from torch import nn
from torch.autograd import Function

from ..op_builder import TransformerBuilder, StochasticTransformerBuilder

# Cuda modules will be imported if needed
transformer_cuda_module = None
stochastic_transformer_cuda_module = None
@@ -483,19 +485,12 @@ class DeepSpeedTransformerLayer(nn.Module):
        self.norm_w = initial_weights[7]
        self.norm_b = initial_biases[7]

        # Import cuda modules if needed
        # Load cuda modules if needed
        global transformer_cuda_module, stochastic_transformer_cuda_module
        if transformer_cuda_module is None or stochastic_transformer_cuda_module is None:
            try:
                transformer_cuda_module = importlib.import_module(
                    "deepspeed.ops.transformer.transformer_cuda")
                stochastic_transformer_cuda_module = importlib.import_module(
                    "deepspeed.ops.transformer.stochastic_transformer_cuda")
            except ImportError as err:
                print(
                    "Unable to import transformer cuda extension, please build DeepSpeed with cuda/cpp extensions."
                )
                raise err
        if transformer_cuda_module is None and not self.config.stochastic_mode:
            transformer_cuda_module = TransformerBuilder().load()
        if stochastic_transformer_cuda_module is None and self.config.stochastic_mode:
            stochastic_transformer_cuda_module = StochasticTransformerBuilder().load()

        # create the layer in cuda kernels.
        cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module
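Each op in this commit follows the same builder contract seen here: a `*Builder().load()` call that returns the already-installed extension or JIT-compiles it on first use. A hedged sketch of loading the transformer kernels directly (names taken from the diff; on machines without the pre-built op, `.load()` triggers a JIT compile):

```python
# Sketch of the builder contract used throughout this commit: .load()
# returns the pre-compiled extension if installed, otherwise JIT-compiles
# it via torch's cpp_extension machinery.
from deepspeed.ops.op_builder import TransformerBuilder

transformer_cuda_module = TransformerBuilder().load()
```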
@@ -7,8 +7,6 @@ import torch
import warnings
import torch.distributed as dist

import apex
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter
@@ -36,22 +34,17 @@ from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop

from .utils import ensure_directory_exists
from ..ops.op_builder import UtilsBuilder
from ..ops.adam import DeepSpeedCPUAdam
from ..ops.adam import FusedAdam

MEMORY_OPT_ALLREDUCE_SIZE = 500000000

try:
    from apex_C import flatten
    from apex_C import unflatten
    from apex import amp
except ImportError:
    try:
        _ = warned_flatten
    except NameError:
        logger.warning(
            "Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten
    # Fail silently so we don't spam logs unnecessarily if user isn't using amp
    pass


def split_half_float_double_csr(tensors):
@@ -201,6 +194,11 @@ class DeepSpeedEngine(Module):
        if self.dump_state():
            print_configuration(self, 'DeepSpeedEngine')

        # Load pre-installed or JIT compile (un)flatten ops
        util_ops = UtilsBuilder().load()
        self.flatten = util_ops.flatten
        self.unflatten = util_ops.unflatten

    def _mpi_check(self, args, dist_init_required):
        if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
            from mpi4py import MPI
@@ -558,6 +556,12 @@ class DeepSpeedEngine(Module):
            amp_params = self.amp_params()
            if self.global_rank == 0:
                logger.info(f"Initializing AMP with these params: {amp_params}")
            try:
                logger.info("Initializing Apex amp from: {}".format(amp.__path__))
            except NameError:
                # If apex/amp is available it will be imported above
                raise RuntimeError(
                    "Unable to import apex/amp, please make sure it is installed")
            self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
            self._broadcast_model()
        elif self.fp16_enabled():
@@ -584,17 +588,18 @@ class DeepSpeedEngine(Module):
        # T|F T F torch.optim.Adam
        # T F T|F DeepSpeedCPUAdam(adam_w_mode)
        # F F T|F FusedAdam(adam_w_mode)
        if torch_adam and adam_w_mode:
            optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
        elif torch_adam and not adam_w_mode:
            optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
        elif self.zero_cpu_offload() and not torch_adam:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
        if torch_adam:
            if adam_w_mode:
                optimizer = torch.optim.AdamW(model_parameters,
                                              **optimizer_parameters)
            else:
                optimizer = torch.optim.Adam(model_parameters,
                                             **optimizer_parameters)
        elif self.zero_cpu_offload():
            optimizer = DeepSpeedCPUAdam(model_parameters,
                                         **optimizer_parameters,
                                         adamw_mode=adam_w_mode)
        elif not self.zero_cpu_offload() and not torch_adam:
            from apex.optimizers.fused_adam import FusedAdam
        else:
            optimizer_parameters[ADAM_W_MODE_PARAM] = adam_w_mode
            optimizer = FusedAdam(model_parameters, **optimizer_parameters)

@@ -614,8 +619,7 @@ class DeepSpeedEngine(Module):
            dynamic_loss_args = self.dynamic_loss_scale_args()
            clip_grad = self.gradient_clipping()
            if isinstance(optimizer,
                          apex.optimizers.FusedAdam) or self.optimizer_name(
                          ) == ONEBIT_ADAM_OPTIMIZER:
                          FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
                if self.dynamic_loss_scale():
                    logger.info('Creating fp16 optimizer with dynamic loss scale')
                    timers = self.timers if self.wall_clock_breakdown() else None
@@ -1072,7 +1076,7 @@ class DeepSpeedEngine(Module):
                ranks=[0])

    def allreduce_bucket(self, bucket):
        tensor = flatten(bucket)
        tensor = self.flatten(bucket)

        tensor_to_allreduce = tensor

@@ -1100,7 +1104,7 @@ class DeepSpeedEngine(Module):

    def allreduce_and_copy(self, small_bucket):
        allreduced = self.allreduce_bucket(small_bucket)
        for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
        for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
            buf.copy_(synced)

    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
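The comment table above compresses the new dispatch into three rows; an illustrative restatement in plain Python (column order inferred as zero_cpu_offload, torch_adam, adam_w_mode; the class names are from the diff, but the function itself is not repository code):

```python
# Illustrative restatement of the optimizer dispatch above; not repository code.
def select_adam_impl(zero_cpu_offload, torch_adam, adam_w_mode):
    if torch_adam:
        # torch_adam=True: plain PyTorch implementations, offload irrelevant
        return 'torch.optim.AdamW' if adam_w_mode else 'torch.optim.Adam'
    if zero_cpu_offload:
        # ZeRO-Offload keeps optimizer state on CPU
        return f'DeepSpeedCPUAdam(adamw_mode={adam_w_mode})'
    return f'FusedAdam(adam_w_mode={adam_w_mode})'
```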
@@ -15,26 +15,15 @@ import collections
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam

from deepspeed.utils import logger
from ...ops.op_builder import UtilsBuilder

#Toggle this to true to enable correctness test
#with gradient partitioning and without
pg_correctness_test = False

try:
    from apex_C import flatten
    from apex_C import unflatten
except ImportError:
    try:
        _ = warned_flatten
    except NameError:
        logger.warning(
            "apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten


def input(msg):
    return
@@ -132,6 +121,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
                 gradient_predivide_factor=1.0,
                 gradient_accumulation_steps=1):

        # Load pre-installed or JIT compile (un)flatten ops
        util_ops = UtilsBuilder().load()
        self.flatten = util_ops.flatten
        self.unflatten = util_ops.unflatten

        if dist.get_rank() == 0:
            logger.info(f"Reduce bucket size {reduce_bucket_size}")
            logger.info(f"Allgather bucket size {allgather_bucket_size}")
@@ -1053,7 +1047,7 @@ class FP16_DeepSpeedZeroOptimizer(object):

    def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
        rank = None
        tensor = flatten(bucket)
        tensor = self.flatten(bucket)

        tensor_to_allreduce = tensor

@@ -1095,7 +1089,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
        with torch.cuda.stream(stream):
            allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
            if rank is None or rank == dist.get_rank(group=self.dp_process_group):
                for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
                for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
                    buf.copy_(synced)

    def allreduce_no_retain(self,
@ -1,8 +1,8 @@
import torch
import torch.distributed as dist
-import apex
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
+from deepspeed.ops.adam import FusedAdam


def _initialize_parameter_parallel_groups(parameter_parallel_size=None):

@ -23,11 +23,14 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    return my_group


-ZERO_SUPPORTED_OPTIMIZERS = [
-    torch.optim.Adam,
-    apex.optimizers.FusedAdam,
-    DeepSpeedCPUAdam
-]
+ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, FusedAdam, DeepSpeedCPUAdam]
+
+# Add apex FusedAdam to supported list if apex is installed
+try:
+    import apex
+    ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
+except ImportError:
+    pass


def is_zero_supported_optimizer(optimizer):
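The body of `is_zero_supported_optimizer` is elided from this hunk, but the list above implies a simple membership test on the optimizer type. A hypothetical check, assuming this module lives at `deepspeed.runtime.zero.utils`:

```python
import torch
# Assumed import path; the file path is not shown in this diff
from deepspeed.runtime.zero.utils import ZERO_SUPPORTED_OPTIMIZERS

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params)
print(type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS)  # True
```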
@ -30,6 +30,7 @@ collections:
    output: true
    permalink: /:collection/:path/
    order:
+      - advanced-install.md
      - getting-started.md
      - azure.md
      - cifar-10.md
86 docs/_tutorials/advanced-install.md Normal file
@ -0,0 +1,86 @@
---
title: "Installation Details"
date: 2020-10-28
---

The quickest way to get started with DeepSpeed is via pip; this will install
the latest release of DeepSpeed, which is not tied to specific PyTorch or CUDA
versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer
to as our 'ops'. By default, all of these extensions/ops will be built
just-in-time (JIT) using [torch's JIT C++ extension loader that relies on
ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and
dynamically link them at runtime.

```bash
pip install deepspeed
```

After installation you can validate your install and see which ops your machine
is compatible with via the DeepSpeed environment report, `ds_report` or
`python -m deepspeed.env_report`. We've found this report useful when debugging
DeepSpeed install or compatibility issues.

```bash
ds_report
```

## Install DeepSpeed from source

After cloning the DeepSpeed repo from GitHub, you can install DeepSpeed in
JIT mode via pip (see below). This install should complete
quickly since it does not compile any C++/CUDA source files.

```bash
pip install .
```

For installs spanning multiple nodes we find it useful to install DeepSpeed
using the
[install.sh](https://github.com/microsoft/DeepSpeed/blob/master/install.sh)
script in the repo. This will build a python wheel locally and copy it to all
the nodes listed in your hostfile (either given via --hostfile, or defaulting to
/job/hostfile).

## Pre-install DeepSpeed Ops

Sometimes we have found it useful to pre-install either some or all DeepSpeed
C++/CUDA ops instead of using the JIT compiled path. In order to support
pre-installation we introduce build environment flags to turn on/off building
specific ops.

You can indicate to our installer (either install.sh or pip install) that you
want to attempt to install all of our ops by setting the `DS_BUILD_OPS`
environment variable to 1, for example:

```bash
DS_BUILD_OPS=1 pip install .
```

We will only install ops that are compatible with your machine; for more
details on which ops are compatible with your system, please try our `ds_report`
tool described above.

If you want to install only a specific op (e.g., FusedLamb), you can view the
op-specific build environment variable (set as `BUILD_VAR`) in the corresponding
op builder class in the
[op\_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder)
directory. For example, to install only the FusedLamb op you would install via:

```bash
DS_BUILD_FUSED_LAMB=1 pip install .
```
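Since each per-op variable falls back to the `DS_BUILD_OPS` default (see the `op_enabled` helper added to setup.py later in this commit), the two kinds of flags can also be combined; a hypothetical example:

```bash
# Build every compatible op except sparse attention
# (individual DS_BUILD_* variables override the DS_BUILD_OPS default)
DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 pip install .
```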

## Feature specific dependencies

Some DeepSpeed features require specific dependencies outside of the general
dependencies of DeepSpeed.

* For Python package dependencies per feature/op, please
see our [requirements
directory](https://github.com/microsoft/DeepSpeed/tree/master/requirements).
* We attempt to keep the system-level dependencies to a minimum; however, some features do require special system-level packages. Please see our `ds_report` tool output to see if you are missing any system-level packages for a given feature.

## Pre-compiled DeepSpeed builds from PyPI

Coming soon
@ -7,9 +7,9 @@ date: 2020-05-15

## Installation

* Installing is as simple as `pip install deepspeed`, [see more details](/tutorials/advanced-install/).
* Please see our [Azure tutorial](/tutorials/azure/) to get started with DeepSpeed on Azure!
* If you're not on Azure, we recommend using our docker image via `docker pull deepspeed/deepspeed:latest`, which contains a pre-installed version of DeepSpeed and all the necessary dependencies.
* If you want to install DeepSpeed manually, we provide an install script `install.sh` to help install on a local machine or across an entire cluster.

## Writing DeepSpeed Models
DeepSpeed model training is accomplished using the DeepSpeed engine. The engine
@ -28,8 +28,9 @@ initiative to enable next-generation AI capabilities at scale, where you can fin
information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale).

# What's New?
-* [2020/10/28] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
-* [DeepSpeed: Extreme-scale model training for everyone]({{ site.press_release_v3 }})
+* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
+* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
+* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone]({{ site.press_release_v3 }})
* [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
* [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html)
* [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)
126 install.sh
@ -15,16 +15,13 @@ By default will install deepspeed and all third party dependecies accross all ma
hostfile (hostfile: /job/hostfile). If no hostfile exists, will only install locally

[optional]
-d, --deepspeed_only     Install only deepspeed and no third party dependencies
-t, --third_party_only   Install only third party dependencies and not deepspeed
-l, --local_only         Install only on local machine
-s, --pip_sudo           Run pip install with sudo (default: no sudo)
-r, --allow_sudo         Allow script to be run by root (probably don't want this, instead use --pip_sudo)
-n, --no_clean           Do not clean prior build state, by default prior build files are removed before building wheels
-m, --pip_mirror         Use the specified pip mirror (default: the default pip mirror)
-H, --hostfile           Path to MPI-style hostfile (default: /job/hostfile)
-a, --apex_commit        Install a specific commit hash of apex, instead of the one deepspeed points to   [removed]
-k, --skip_requirements  Skip installing DeepSpeed requirements
-v, --verbose            Verbose logging   [added]
-h, --help               This help text
"""
}

@ -42,27 +39,12 @@ apex_commit=""
skip_requirements=0
allow_sudo=0
no_clean=0
+verbose=0

while [[ $# -gt 0 ]]
do
key="$1"
case $key in
    -d|--deepspeed_only)
    deepspeed_install=1;
    third_party_install=0;
    ds_only=1;
    shift
    ;;
    -t|--third_party_only)
    deepspeed_install=0;
    third_party_install=1;
    tp_only=1;
    shift
    ;;
    -l|--local_only)
    local_only=1;
    shift
    ;;
    -s|--pip_sudo)
    pip_sudo=1;
    shift

@ -72,13 +54,8 @@ case $key in
    shift
    shift
    ;;
    -a|--apex_commit)        [removed]
    apex_commit=$2;          [removed]
    shift                    [removed]
    shift                    [removed]
    ;;                       [removed]
    -k|--skip_requirements)
    skip_requirements=1;
    -v|--verbose)            [added]
    verbose=1;               [added]
    shift
    ;;
    -r|--allow_sudo)

@ -126,12 +103,18 @@ if [ "$ds_only" == "1" ] && [ "$tp_only" == "1" ]; then
    exit 1
fi

+if [ "$verbose" == "1" ]; then
+    VERBOSE="-v"
+else
+    VERBOSE=""
+fi

rm_if_exist() {
    echo "Attempting to remove $1"
    if [ -f $1 ]; then
-        rm -v $1
+        rm $VERBOSE $1
    elif [ -d $1 ]; then
-        rm -vr $1
+        rm -r $VERBOSE $1
    fi
}

@ -141,10 +124,6 @@ if [ "$no_clean" == "0" ]; then
    rm_if_exist dist
    rm_if_exist build
    rm_if_exist deepspeed.egg-info
-    # remove apex build files
-    rm_if_exist third_party/apex/dist
-    rm_if_exist third_party/apex/build
-    rm_if_exist third_party/apex/apex.egg-info
fi

if [ "$pip_sudo" == "1" ]; then

@ -154,60 +133,25 @@ else
fi

if [ "$pip_mirror" != "" ]; then
-    PIP_INSTALL="pip install -v -i $pip_mirror"
+    PIP_INSTALL="pip install $VERBOSE -i $pip_mirror"
else
-    PIP_INSTALL="pip install -v"
+    PIP_INSTALL="pip install $VERBOSE"
fi


if [ ! -f $hostfile ]; then
    echo "No hostfile exists at $hostfile, installing locally"
    local_only=1
fi

if [ "$skip_requirements" == "0" ]; then
    # Ensure dependencies are installed locally
    $PIP_SUDO $PIP_INSTALL -r requirements/requirements.txt
fi

# Build wheels
-if [ "$third_party_install" == "1" ]; then
-    echo "Checking out sub-module(s)"
-    git submodule update --init --recursive
-
-    echo "Building apex wheel"
-    cd third_party/apex
-
-    if [ "$apex_commit" != "" ]; then
-        echo "Installing a non-standard version of apex at commit: $apex_commit"
-        git fetch
-        git checkout $apex_commit
-    fi
-
-    python setup.py -v --cpp_ext --cuda_ext bdist_wheel
-    cd -
-
-    echo "Installing apex locally so that deepspeed will build"
-    $PIP_SUDO pip uninstall -y apex
-    $PIP_SUDO $PIP_INSTALL third_party/apex/dist/apex*.whl
-fi
-if [ "$deepspeed_install" == "1" ]; then
-    echo "Building deepspeed wheel"
-    python setup.py -v bdist_wheel
-fi
+echo "Building deepspeed wheel"
+python setup.py $VERBOSE bdist_wheel

if [ "$local_only" == "1" ]; then
-    if [ "$deepspeed_install" == "1" ]; then
-        echo "Installing deepspeed"
-        $PIP_SUDO pip uninstall -y deepspeed
-        $PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl
-        # -I to exclude local directory files
-        python -I basic_install_test.py
-        if [ $? == 0 ]; then
-            echo "Installation is successful"
-        else
-            echo "Installation failed"
-        fi
-    fi
+    echo "Installing deepspeed"
+    $PIP_SUDO pip uninstall -y deepspeed
+    $PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl
+    ds_report
else
    local_path=`pwd`
    if [ -f $hostfile ]; then

@ -216,28 +160,16 @@ else
        echo "hostfile not found, cannot proceed"
        exit 1
    fi
-    export PDSH_RCMD_TYPE=ssh;
+    export PDSH_RCMD_TYPE=ssh
    tmp_wheel_path="/tmp/deepspeed_wheels"

    pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; else mkdir -pv $tmp_wheel_path; fi"
    pdcp -w $hosts requirements/requirements.txt ${tmp_wheel_path}/
    if [ "$skip_requirements" == "0" ]; then
        pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
    fi
-    if [ "$third_party_install" == "1" ]; then
-        pdsh -w $hosts "$PIP_SUDO pip uninstall -y apex"
-        pdcp -w $hosts third_party/apex/dist/apex*.whl $tmp_wheel_path/
-        pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/apex*.whl"
-        pdsh -w $hosts 'python -c "import apex"'
-    fi
-    if [ "$deepspeed_install" == "1" ]; then
-        echo "Installing deepspeed"
-        pdsh -w $hosts "$PIP_SUDO pip uninstall -y deepspeed"
-        pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/
-        pdcp -w $hosts basic_install_test.py $tmp_wheel_path/
-        pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl"
-        pdsh -w $hosts "python $tmp_wheel_path/basic_install_test.py"
-        echo "Installation is successful"
-    fi
-    pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl $tmp_wheel_path/basic_install_test.py; rmdir $tmp_wheel_path; fi"
+    echo "Installing deepspeed"
+    pdsh -w $hosts "$PIP_SUDO pip uninstall -y deepspeed"
+    pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/
+    pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl"
+    pdsh -w $hosts "ds_report"
+    pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rmdir $tmp_wheel_path; fi"
fi
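For reference, a couple of hypothetical invocations of the script above, with flag names taken from its usage text:

```bash
# Build the wheel and install on this machine only
bash install.sh --local_only

# Build once, then copy and pip-install the wheel on every host
# listed in a custom hostfile, running pip with sudo
bash install.sh --hostfile /path/to/hostfile --pip_sudo
```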
20 op_builder/__init__.py Normal file
@ -0,0 +1,20 @@
from .cpu_adam import CPUAdamBuilder
from .fused_adam import FusedAdamBuilder
from .fused_lamb import FusedLambBuilder
from .sparse_attn import SparseAttnBuilder
from .transformer import TransformerBuilder
from .stochastic_transformer import StochasticTransformerBuilder
from .utils import UtilsBuilder

# TODO: infer this list instead of hard coded
# List of all available ops
__op_builders__ = [
    CPUAdamBuilder(),
    FusedAdamBuilder(),
    FusedLambBuilder(),
    SparseAttnBuilder(),
    TransformerBuilder(),
    StochasticTransformerBuilder(),
    UtilsBuilder()
]
ALL_OPS = {op.name: op for op in __op_builders__}
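A small sketch of how this registry can be consumed; it mirrors the `compatible_ops` dictionary that setup.py builds later in this commit:

```python
from op_builder import ALL_OPS

# Report which ops could be built or JIT-loaded on this machine
for name, builder in ALL_OPS.items():
    print(f"{name}: compatible={builder.is_compatible()}")
```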
245 op_builder/builder.py Normal file
@ -0,0 +1,245 @@
import os
import time
import torch
import importlib
from pathlib import Path
import subprocess
from abc import ABC, abstractmethod

YELLOW = '\033[93m'
END = '\033[0m'
WARNING = f"{YELLOW} [WARNING] {END}"

DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"


def assert_no_cuda_mismatch():
    import torch.utils.cpp_extension
    cuda_home = torch.utils.cpp_extension.CUDA_HOME
    assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
    # Ensure there is not a cuda version mismatch between torch and nvcc compiler
    output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
                                     universal_newlines=True)
    output_split = output.split()
    release_idx = output_split.index("release")
    release = output_split[release_idx + 1].replace(',', '').split(".")
    # Ignore patch versions, only look at major + minor
    installed_cuda_version = ".".join(release[:2])
    torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    # This is a show-stopping error, should probably not proceed past this
    if installed_cuda_version != torch_cuda_version:
        raise Exception(
            f"Installed CUDA version {installed_cuda_version} does not match the "
            f"version torch was compiled with {torch.version.cuda}, unable to compile "
            "cuda/cpp extensions without a matching cuda version.")


def assert_torch_info(torch_info):
    install_torch_version = torch_info['version']
    install_cuda_version = torch_info['cuda_version']

    current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    current_torch_version = ".".join(torch.__version__.split('.')[:2])

    if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
        raise RuntimeError(
            "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
            "with a different version than what is being used at runtime. Please re-install "
            f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
            f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
            f"torch={current_torch_version}, cuda={current_cuda_version}")


class OpBuilder(ABC):
    def __init__(self, name):
        self.name = name
        self.jit_mode = False

    @abstractmethod
    def absolute_name(self):
        '''
        Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
        will be installed as something like: deepspeed/ops/adam/cpu_adam.so
        '''
        pass

    @abstractmethod
    def sources(self):
        '''
        Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        pass

    def include_paths(self):
        '''
        Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        return []

    def nvcc_args(self):
        '''
        Returns optional list of compiler flags to forward to nvcc when building CUDA sources
        '''
        return []

    def cxx_args(self):
        '''
        Returns optional list of compiler flags to forward to the build
        '''
        return []

    def is_compatible(self):
        '''
        Check if all non-python dependencies are satisfied to build this op
        '''
        return True

    def python_requirements(self):
        '''
        Override if op wants to define special dependencies, otherwise will
        take self.name and load requirements-<op-name>.txt if it exists.
        '''
        path = f'requirements/requirements-{self.name}.txt'
        requirements = []
        if os.path.isfile(path):
            with open(path, 'r') as fd:
                requirements = [r.strip() for r in fd.readlines()]
        return requirements

    def command_exists(self, cmd):
        if '|' in cmd:
            cmds = cmd.split("|")
        else:
            cmds = [cmd]
        valid = False
        for cmd in cmds:
            result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
            valid = valid or result.wait() == 0

        if not valid and len(cmds) > 1:
            print(
                f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!"
            )
        elif not valid and len(cmds) == 1:
            print(
                f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!"
            )
        return valid

    def warning(self, msg):
        print(f"{WARNING} {msg}")

    def deepspeed_src_path(self, code_path):
        if os.path.isabs(code_path):
            return code_path
        else:
            return os.path.join(Path(__file__).parent.parent.absolute(), code_path)

    def builder(self):
        from torch.utils.cpp_extension import CppExtension
        return CppExtension(name=self.absolute_name(),
                            sources=self.sources(),
                            include_dirs=self.include_paths(),
                            extra_compile_args={'cxx': self.cxx_args()})

    def load(self, verbose=True):
        from ...git_version_info import installed_ops, torch_info
        if installed_ops[self.name]:
            # Ensure the op we're about to load was compiled with the same
            # torch/cuda versions we are currently using at runtime.
            if isinstance(self, CUDAOpBuilder):
                assert_torch_info(torch_info)

            return importlib.import_module(self.absolute_name())
        else:
            return self.jit_load(verbose)

    def jit_load(self, verbose=True):
        if not self.is_compatible():
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue."
            )
        try:
            import ninja
        except ImportError:
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op due to ninja not being installed."
            )

        if isinstance(self, CUDAOpBuilder):
            assert_no_cuda_mismatch()

        self.jit_mode = True
        from torch.utils.cpp_extension import load

        # Ensure directory exists to prevent race condition in some cases
        ext_path = os.path.join(
            os.environ.get('TORCH_EXTENSIONS_DIR',
                           DEFAULT_TORCH_EXTENSION_PATH),
            self.name)
        os.makedirs(ext_path, exist_ok=True)

        start_build = time.time()
        op_module = load(
            name=self.name,
            sources=[self.deepspeed_src_path(path) for path in self.sources()],
            extra_include_paths=[
                self.deepspeed_src_path(path) for path in self.include_paths()
            ],
            extra_cflags=self.cxx_args(),
            extra_cuda_cflags=self.nvcc_args(),
            verbose=verbose)
        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")
        return op_module


class CUDAOpBuilder(OpBuilder):
    def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
        args = []
        if self.jit_mode:
            # Compile for underlying architecture since we know it at runtime
            CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability()
            compute_capability = f"{CC_MAJOR}{CC_MINOR}"
            args.append('-gencode')
            args.append(
                f'arch=compute_{compute_capability},code=compute_{compute_capability}')
        else:
            # Cross-compile mode, compile for various architectures
            for compute_capability in cross_compile_archs:
                args.append('-gencode')
                args.append(
                    f'arch=compute_{compute_capability},code=compute_{compute_capability}'
                )
        return args

    def version_dependent_macros(self):
        # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        version_ge_1_1 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
            version_ge_1_1 = ['-DVERSION_GE_1_1']
        version_ge_1_3 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
            version_ge_1_3 = ['-DVERSION_GE_1_3']
        version_ge_1_5 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
            version_ge_1_5 = ['-DVERSION_GE_1_5']
        return version_ge_1_1 + version_ge_1_3 + version_ge_1_5

    def is_compatible(self):
        return super().is_compatible()

    def builder(self):
        from torch.utils.cpp_extension import CUDAExtension
        assert_no_cuda_mismatch()
        return CUDAExtension(name=self.absolute_name(),
                             sources=self.sources(),
                             include_dirs=self.include_paths(),
                             extra_compile_args={
                                 'cxx': self.cxx_args(),
                                 'nvcc': self.nvcc_args()
                             })
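At runtime a builder is consumed through `load()`: it imports the pre-installed extension if setup.py built it, and otherwise falls back to `jit_load()`. A minimal usage sketch, matching how stage2.py uses it earlier in this commit:

```python
from deepspeed.ops.op_builder import UtilsBuilder

util_ops = UtilsBuilder().load()  # may trigger a one-time ninja JIT build
flatten = util_ops.flatten        # compiled flatten/unflatten kernels
unflatten = util_ops.unflatten
```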
77 op_builder/cpu_adam.py Normal file
@ -0,0 +1,77 @@
import os
import torch
import warnings
from .builder import CUDAOpBuilder


class CPUAdamBuilder(CUDAOpBuilder):
    BUILD_VAR = "DS_BUILD_CPU_ADAM"
    NAME = "cpu_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/custom_cuda_kernel.cu']

    def include_paths(self):
        CUDA_INCLUDE = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")
        return ['csrc/includes', CUDA_INCLUDE]

    def available_vector_instructions(self):
        try:
            import cpufeature
        except ImportError:
            warnings.warn(
                'import cpufeature failed - CPU vector optimizations are not available for CPUAdam'
            )
            return {}

        cpu_vector_instructions = {}
        try:
            cpu_vector_instructions = cpufeature.CPUFeature
        except Exception:
            warnings.warn(
                'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam'
            )
            return {}

        return cpu_vector_instructions

    def cxx_args(self):
        CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
        cpu_info = self.available_vector_instructions()
        SIMD_WIDTH = ''
        if 'Intel' in cpu_info.get('VendorId', ''):
            if cpu_info.get('AVX512f', False):
                SIMD_WIDTH = '-D__AVX512__'
            elif cpu_info.get('AVX2', False):
                SIMD_WIDTH = '-D__AVX256__'

        return [
            '-O3',
            '-std=c++14',
            f'-L{CUDA_LIB64}',
            '-lcudart',
            '-lcublas',
            '-g',
            '-Wno-reorder',
            '-march=native',
            '-fopenmp',
            SIMD_WIDTH
        ]

    def nvcc_args(self):
        args = [
            '-O3',
            '--use_fast_math',
            '-std=c++14',
            '-U__CUDA_NO_HALF_OPERATORS__',
            '-U__CUDA_NO_HALF_CONVERSIONS__',
            '-U__CUDA_NO_HALF2_OPERATORS__'
        ]
        args += self.compute_capability_args()
        return args
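To see the flags this builder keys on when picking `-D__AVX512__` vs `-D__AVX256__`, you can query `cpufeature` directly, assuming the optional package is installed:

```python
import cpufeature

info = cpufeature.CPUFeature  # plain dict of detected CPU features
print(info.get('VendorId'), info.get('AVX512f'), info.get('AVX2'))
```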
25 op_builder/fused_adam.py Normal file
@ -0,0 +1,25 @@
import torch
from .builder import CUDAOpBuilder


class FusedAdamBuilder(CUDAOpBuilder):
    BUILD_VAR = "DS_BUILD_FUSED_ADAM"
    NAME = "fused_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        return ['csrc/adam/fused_adam_frontend.cpp', 'csrc/adam/multi_tensor_adam.cu']

    def include_paths(self):
        return ['csrc/includes']

    def cxx_args(self):
        return ['-O3'] + self.version_dependent_macros()

    def nvcc_args(self):
        return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()
25 op_builder/fused_lamb.py Normal file
@ -0,0 +1,25 @@
import torch
from .builder import CUDAOpBuilder


class FusedLambBuilder(CUDAOpBuilder):
    BUILD_VAR = 'DS_BUILD_FUSED_LAMB'
    NAME = "fused_lamb"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.lamb.{self.NAME}_op'

    def sources(self):
        return ['csrc/lamb/fused_lamb_cuda.cpp', 'csrc/lamb/fused_lamb_cuda_kernel.cu']

    def include_paths(self):
        return ['csrc/includes']

    def cxx_args(self):
        return ['-O3'] + self.version_dependent_macros()

    def nvcc_args(self):
        return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()
36 op_builder/sparse_attn.py Normal file
@ -0,0 +1,36 @@
import torch
import warnings
from .builder import OpBuilder


class SparseAttnBuilder(OpBuilder):
    BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
    NAME = "sparse_attn"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.sparse_attention.{self.NAME}_op'

    def sources(self):
        return ['csrc/sparse_attention/utils.cpp']

    def cxx_args(self):
        return ['-O2', '-fopenmp']

    def is_compatible(self):
        # Check to see if llvm and cmake are installed since they are dependencies
        required_commands = ['llvm-config|llvm-config-9', 'cmake']
        command_status = list(map(self.command_exists, required_commands))
        deps_compatible = all(command_status)

        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5
        if not torch_compatible:
            self.warning(
                f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}'
            )

        return super().is_compatible() and deps_compatible and torch_compatible
18 op_builder/stochastic_transformer.py Normal file
@ -0,0 +1,18 @@
import torch
from .transformer import TransformerBuilder


class StochasticTransformerBuilder(TransformerBuilder):
    BUILD_VAR = "DS_BUILD_STOCHASTIC_TRANSFORMER"
    NAME = "stochastic_transformer"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.transformer.{self.NAME}_op'

    def nvcc_args(self):
        args = super().nvcc_args()
        args.append('-D__STOCHASTIC_MODE__')
        return args
44 op_builder/transformer.py Normal file
@ -0,0 +1,44 @@
import torch
from .builder import CUDAOpBuilder


class TransformerBuilder(CUDAOpBuilder):
    BUILD_VAR = "DS_BUILD_TRANSFORMER"
    NAME = "transformer"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        return f'deepspeed.ops.transformer.{self.NAME}_op'

    def sources(self):
        return [
            'csrc/transformer/ds_transformer_cuda.cpp',
            'csrc/transformer/cublas_wrappers.cu',
            'csrc/transformer/transform_kernels.cu',
            'csrc/transformer/gelu_kernels.cu',
            'csrc/transformer/dropout_kernels.cu',
            'csrc/transformer/normalize_kernels.cu',
            'csrc/transformer/softmax_kernels.cu',
            'csrc/transformer/general_kernels.cu'
        ]

    def include_paths(self):
        return ['csrc/includes']

    def nvcc_args(self):
        args = [
            '-O3',
            '--use_fast_math',
            '-std=c++14',
            '-U__CUDA_NO_HALF_OPERATORS__',
            '-U__CUDA_NO_HALF_CONVERSIONS__',
            '-U__CUDA_NO_HALF2_OPERATORS__'
        ]

        return args + self.compute_capability_args()

    def cxx_args(self):
        return ['-O3', '-std=c++14', '-g', '-Wno-reorder']
15 op_builder/utils.py Normal file
@ -0,0 +1,15 @@
from .builder import OpBuilder


class UtilsBuilder(OpBuilder):
    BUILD_VAR = "DS_BUILD_UTILS"
    NAME = "utils"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.{self.NAME}_op'

    def sources(self):
        return ['csrc/utils/flatten_unflatten.cpp']
@ -2,5 +2,6 @@ torch>=1.2
torchvision>=0.4.0
tqdm
psutil
-cpufeature
tensorboardX==1.8
+ninja
+cpufeature
365 setup.py
@ -16,7 +16,7 @@ import warnings
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension

-VERSION = "0.3.0"
+import op_builder


def fetch_requirements(path):

@ -24,88 +24,33 @@ def fetch_requirements(path):
    return [r.strip() for r in fd.readlines()]


-def available_vector_instructions():
-    try:
-        import cpufeature
-    except ImportError:
-        warnings.warn(
-            'import cpufeature failed - CPU vector optimizations are not available for CPUAdam'
-        )
-        return {}
-
-    cpu_vector_instructions = {}
-    try:
-        cpu_vector_instructions = cpufeature.CPUFeature
-    except Exception:
-        warnings.warn(
-            'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam'
-        )
-        return {}
-
-    return cpu_vector_instructions
-
-
install_requires = fetch_requirements('requirements/requirements.txt')
-dev_requires = fetch_requirements('requirements/requirements-dev.txt')
-sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
+extras_require = {
+    '1bit_adam': fetch_requirements('requirements/requirements-1bit-adam.txt'),
+    'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'),
+    'dev': fetch_requirements('requirements/requirements-dev.txt'),
+}

# If MPI is available add 1bit-adam requirements
if torch.cuda.is_available():
    if shutil.which('ompi_info') or shutil.which('mpiname'):
-        onebit_adam_requires = fetch_requirements(
-            'requirements/requirements-1bit-adam.txt')
-        onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
-        install_requires += onebit_adam_requires
+        cupy = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}"
+        extras_require['1bit_adam'].append(cupy)

-# Constants for each op
-LAMB = "lamb"
-TRANSFORMER = "transformer"
-SPARSE_ATTN = "sparse-attn"
-CPU_ADAM = "cpu-adam"
-
-cpu_vector_instructions = available_vector_instructions()
-
-# Build environment variables for custom builds
-DS_BUILD_LAMB_MASK = 1
-DS_BUILD_TRANSFORMER_MASK = 10
-DS_BUILD_SPARSE_ATTN_MASK = 100
-DS_BUILD_CPU_ADAM_MASK = 1000
-
-# Allow for build_cuda to turn on or off all ops
-DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK
-DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS
-
-# Set default of each op based on if build_cuda is set
-OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
-DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM', 0)) * DS_BUILD_CPU_ADAM_MASK
-DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
-DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
-                                          OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
-DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
-                                          OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK
-
-# Final effective mask is the bitwise OR of each op
-BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN
-              | DS_BUILD_CPU_ADAM)
-
-install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, CPU_ADAM], False)
-if BUILD_MASK & DS_BUILD_LAMB:
-    install_ops[LAMB] = True
-if BUILD_MASK & DS_BUILD_CPU_ADAM:
-    install_ops[CPU_ADAM] = True
-if BUILD_MASK & DS_BUILD_TRANSFORMER:
-    install_ops[TRANSFORMER] = True
-if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
-    install_ops[SPARSE_ATTN] = True
-if len(install_ops) == 0:
-    print("Building without any cuda/cpp extensions")
-print(f'BUILD_MASK={BUILD_MASK}, install_ops={install_ops}')
+# Make an [all] extra that installs all needed dependencies
+all_extras = set()
+for extra in extras_require.items():
+    for req in extra[1]:
+        all_extras.add(req)
+extras_require['all'] = list(all_extras)

cmdclass = {}

+# For any pre-installed ops force disable ninja
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)

-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
+TORCH_MAJOR = torch.__version__.split('.')[0]
+TORCH_MINOR = torch.__version__.split('.')[1]

if not torch.cuda.is_available():
    # Fix to allow docker builds, similar to https://github.com/NVIDIA/apex/issues/486

@ -116,230 +61,118 @@ if not torch.cuda.is_available():
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

-# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
-version_ge_1_1 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
-    version_ge_1_1 = ['-DVERSION_GE_1_1']
-version_ge_1_3 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    version_ge_1_3 = ['-DVERSION_GE_1_3']
-version_ge_1_5 = []
-if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
-    version_ge_1_5 = ['-DVERSION_GE_1_5']
-version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
-
-SIMD_WIDTH = ''
-if cpu_vector_instructions.get('AVX512f', False):
-    SIMD_WIDTH = '-D__AVX512__'
-elif cpu_vector_instructions.get('AVX2', False):
-    SIMD_WIDTH = '-D__AVX256__'
-print("SIMD_WIDTH = ", SIMD_WIDTH)
-
ext_modules = []

-## Lamb ##
-if BUILD_MASK & DS_BUILD_LAMB:
-    ext_modules.append(
-        CUDAExtension(name='deepspeed.ops.lamb.fused_lamb_cuda',
-                      sources=[
-                          'csrc/lamb/fused_lamb_cuda.cpp',
-                          'csrc/lamb/fused_lamb_cuda_kernel.cu'
-                      ],
-                      include_dirs=['csrc/includes'],
-                      extra_compile_args={
-                          'cxx': [
-                              '-O3',
-                          ] + version_dependent_macros,
-                          'nvcc': ['-O3',
-                                   '--use_fast_math'] + version_dependent_macros
-                      }))
+from op_builder import ALL_OPS

-## Adam ##
-if BUILD_MASK & DS_BUILD_CPU_ADAM:
-    ext_modules.append(
-        CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op',
-                      sources=[
-                          'csrc/adam/cpu_adam.cpp',
-                          'csrc/adam/custom_cuda_kernel.cu',
-                      ],
-                      include_dirs=['csrc/includes',
-                                    '/usr/local/cuda/include'],
-                      extra_compile_args={
-                          'cxx': [
-                              '-O3',
-                              '-std=c++14',
-                              '-L/usr/local/cuda/lib64',
-                              '-lcudart',
-                              '-lcublas',
-                              '-g',
-                              '-Wno-reorder',
-                              '-march=native',
-                              '-fopenmp',
-                              SIMD_WIDTH
-                          ],
-                          'nvcc': [
-                              '-O3',
-                              '--use_fast_math',
-                              '-gencode',
-                              'arch=compute_61,code=compute_61',
-                              '-gencode',
-                              'arch=compute_70,code=compute_70',
-                              '-std=c++14',
-                              '-U__CUDA_NO_HALF_OPERATORS__',
-                              '-U__CUDA_NO_HALF_CONVERSIONS__',
-                              '-U__CUDA_NO_HALF2_OPERATORS__'
-                          ]
-                      }))
-
-## Transformer ##
-if BUILD_MASK & DS_BUILD_TRANSFORMER:
-    ext_modules.append(
-        CUDAExtension(name='deepspeed.ops.transformer.transformer_cuda',
-                      sources=[
-                          'csrc/transformer/ds_transformer_cuda.cpp',
-                          'csrc/transformer/cublas_wrappers.cu',
-                          'csrc/transformer/transform_kernels.cu',
-                          'csrc/transformer/gelu_kernels.cu',
-                          'csrc/transformer/dropout_kernels.cu',
-                          'csrc/transformer/normalize_kernels.cu',
-                          'csrc/transformer/softmax_kernels.cu',
-                          'csrc/transformer/general_kernels.cu'
-                      ],
-                      include_dirs=['csrc/includes'],
-                      extra_compile_args={
-                          'cxx': ['-O3',
-                                  '-std=c++14',
-                                  '-g',
-                                  '-Wno-reorder'],
-                          'nvcc': [
-                              '-O3',
-                              '--use_fast_math',
-                              '-gencode',
-                              'arch=compute_61,code=compute_61',
-                              '-gencode',
-                              'arch=compute_60,code=compute_60',
-                              '-gencode',
-                              'arch=compute_70,code=compute_70',
-                              '-std=c++14',
-                              '-U__CUDA_NO_HALF_OPERATORS__',
-                              '-U__CUDA_NO_HALF_CONVERSIONS__',
-                              '-U__CUDA_NO_HALF2_OPERATORS__'
-                          ]
-                      }))
-    ext_modules.append(
-        CUDAExtension(name='deepspeed.ops.transformer.stochastic_transformer_cuda',
-                      sources=[
-                          'csrc/transformer/ds_transformer_cuda.cpp',
-                          'csrc/transformer/cublas_wrappers.cu',
-                          'csrc/transformer/transform_kernels.cu',
-                          'csrc/transformer/gelu_kernels.cu',
-                          'csrc/transformer/dropout_kernels.cu',
-                          'csrc/transformer/normalize_kernels.cu',
-                          'csrc/transformer/softmax_kernels.cu',
-                          'csrc/transformer/general_kernels.cu'
-                      ],
-                      include_dirs=['csrc/includes'],
-                      extra_compile_args={
-                          'cxx': ['-O3',
-                                  '-std=c++14',
-                                  '-g',
-                                  '-Wno-reorder'],
-                          'nvcc': [
-                              '-O3',
-                              '--use_fast_math',
-                              '-gencode',
-                              'arch=compute_61,code=compute_61',
-                              '-gencode',
-                              'arch=compute_60,code=compute_60',
-                              '-gencode',
-                              'arch=compute_70,code=compute_70',
-                              '-std=c++14',
-                              '-U__CUDA_NO_HALF_OPERATORS__',
-                              '-U__CUDA_NO_HALF_CONVERSIONS__',
-                              '-U__CUDA_NO_HALF2_OPERATORS__',
-                              '-D__STOCHASTIC_MODE__'
-                          ]
-                      }))
+# Default to pre-install kernels to false so we rely on JIT
+BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', 0))
+print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")


def command_exists(cmd):
-    if '|' in cmd:
-        cmds = cmd.split("|")
-    else:
-        cmds = [cmd]
-    valid = False
-    for cmd in cmds:
-        result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
-        valid = valid or result.wait() == 0
-    return valid
+    result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
+    return result.wait() == 0


-## Sparse transformer ##
-if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
-    # Check to see if llvm and cmake are installed since they are dependencies
-    required_commands = ['llvm-config|llvm-config-9', 'cmake']
-    command_status = list(map(command_exists, required_commands))
-    if not all(command_status):
-        zipped_status = list(zip(required_commands, command_status))
-        warnings.warn(
-            f'Missing non-python requirements, please install the missing packages: {zipped_status}'
-        )
-        warnings.warn(
-            'Skipping sparse attention installation due to missing required packages')
-        # remove from installed ops list
-        install_ops[SPARSE_ATTN] = False
-    elif TORCH_MAJOR == 1 and TORCH_MINOR >= 5:
-        ext_modules.append(
-            CppExtension(name='deepspeed.ops.sparse_attention.cpp_utils',
-                         sources=['csrc/sparse_attention/utils.cpp'],
-                         extra_compile_args={'cxx': ['-O2',
-                                                     '-fopenmp']}))
-        # Add sparse attention requirements
-        install_requires += sparse_attn_requires
-    else:
-        warnings.warn('Unable to meet requirements to install sparse attention')
-        # remove from installed ops list
-        install_ops[SPARSE_ATTN] = False
-
-# Add development requirements
-install_requires += dev_requires
+def op_enabled(op_name):
+    assert hasattr(ALL_OPS[op_name], 'BUILD_VAR'), \
+        f"{op_name} is missing BUILD_VAR field"
+    env_var = ALL_OPS[op_name].BUILD_VAR
+    return int(os.environ.get(env_var, BUILD_OP_DEFAULT))
+
+
+install_ops = dict.fromkeys(ALL_OPS.keys(), False)
+for op_name, builder in ALL_OPS.items():
+    op_compatible = builder.is_compatible()
+
+    # If op is compatible update install reqs so it can potentially build/run later
+    if op_compatible:
+        reqs = builder.python_requirements()
+        install_requires += builder.python_requirements()
+
+    # If op install enabled, add builder to extensions
+    if op_enabled(op_name) and op_compatible:
+        install_ops[op_name] = op_enabled(op_name)
+        ext_modules.append(builder.builder())
+
+compatible_ops = {op_name: op.is_compatible() for (op_name, op) in ALL_OPS.items()}

print(f'Install Ops={install_ops}')

# Write out version/git info
git_hash_cmd = "git rev-parse --short HEAD"
git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
-if command_exists('git'):
-    result = subprocess.check_output(git_hash_cmd, shell=True)
-    git_hash = result.decode('utf-8').strip()
-    result = subprocess.check_output(git_branch_cmd, shell=True)
-    git_branch = result.decode('utf-8').strip()
+if command_exists('git') and 'DS_BUILD_STRING' not in os.environ:
+    try:
+        result = subprocess.check_output(git_hash_cmd, shell=True)
+        git_hash = result.decode('utf-8').strip()
+        result = subprocess.check_output(git_branch_cmd, shell=True)
+        git_branch = result.decode('utf-8').strip()
+    except subprocess.CalledProcessError:
+        git_hash = "unknown"
+        git_branch = "unknown"
else:
    git_hash = "unknown"
    git_branch = "unknown"
-print(f"version={VERSION}+{git_hash}, git_hash={git_hash}, git_branch={git_branch}")

+# Parse the DeepSpeed version string from version.txt
+version_str = open('version.txt', 'r').read().strip()
+
+# Build specifiers like .devX can be added at install time. Otherwise, add the git hash.
+# example: DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel
+#version_str += os.environ.get('DS_BUILD_STRING', f'+{git_hash}')
+
+# Building wheel for distribution, update version file
+if 'DS_BUILD_STRING' in os.environ:
+    # Build string env specified, probably building for distribution
+    with open('build.txt', 'w') as fd:
+        fd.write(os.environ.get('DS_BUILD_STRING'))
+    version_str += os.environ.get('DS_BUILD_STRING')
+elif os.path.isfile('build.txt'):
+    # build.txt exists, probably installing from distribution
+    with open('build.txt', 'r') as fd:
+        version_str += fd.read().strip()
+else:
+    # None of the above, probably installing from source
+    version_str += f'+{git_hash}'
+
+torch_version = ".".join([TORCH_MAJOR, TORCH_MINOR])
+cuda_version = ".".join(torch.version.cuda.split('.')[:2])
+torch_info = {"version": torch_version, "cuda_version": cuda_version}
+
+print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}")
with open('deepspeed/git_version_info_installed.py', 'w') as fd:
-    fd.write(f"version='{VERSION}+{git_hash}'\n")
+    fd.write(f"version='{version_str}'\n")
    fd.write(f"git_hash='{git_hash}'\n")
    fd.write(f"git_branch='{git_branch}'\n")
    fd.write(f"installed_ops={install_ops}\n")
+    fd.write(f"compatible_ops={compatible_ops}\n")
+    fd.write(f"torch_info={torch_info}\n")

print(f'install_requires={install_requires}')
+print(f'compatible_ops={compatible_ops}')
print(f'ext_modules={ext_modules}')

setup(name='deepspeed',
-      version=f"{VERSION}+{git_hash}",
+      version=version_str,
      description='DeepSpeed library',
      author='DeepSpeed Team',
      author_email='deepspeed@microsoft.com',
      url='http://deepspeed.ai',
      install_requires=install_requires,
+      extras_require=extras_require,
      packages=find_packages(exclude=["docker",
-                                      "third_party",
-                                      "csrc"]),
+                                      "third_party"]),
      package_data={'deepspeed.ops.sparse_attention.trsrc': ['*.tr']},
-      scripts=['bin/deepspeed',
-               'bin/deepspeed.pt',
-               'bin/ds',
-               'bin/ds_ssh'],
+      include_package_data=True,
+      scripts=[
+          'bin/deepspeed',
+          'bin/deepspeed.pt',
+          'bin/ds',
+          'bin/ds_ssh',
+          'bin/ds_report'
+      ],
      classifiers=[
          'Programming Language :: Python :: 3.6',
          'Programming Language :: Python :: 3.7',
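As the version logic above suggests, a release build pins a PEP 440 build string instead of the git hash; a hypothetical release command, reusing the example string from the code comment:

```bash
# Stamp a .dev build string into build.txt and the wheel version
DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel
```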
|
@ -363,10 +363,18 @@ except ImportError:
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
pdtype = x.dtype
|
||||
x = x.float()
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
return self.weight * x.to(pdtype) + self.bias
|
||||
|
||||
#def forward(self, x):
|
||||
# u = x.mean(-1, keepdim=True)
|
||||
# s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
# x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
# return self.weight * x + self.bias
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
|
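The fix above upcasts to fp32 for the normalization statistics and casts back at the end; a standalone illustration of the pattern, with hypothetical tensors and an arbitrarily chosen epsilon:

```python
import torch

x = torch.randn(2, 4).half()  # fp16 activations
pdtype = x.dtype

xf = x.float()                # compute statistics in fp32 for stability
u = xf.mean(-1, keepdim=True)
s = (xf - u).pow(2).mean(-1, keepdim=True)
y = (xf - u) / torch.sqrt(s + 1e-12)

out = y.to(pdtype)            # cast back to the input dtype
assert out.dtype == torch.float16
```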
@ -12,6 +12,8 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.pipe.topology import *
PipeTopo = PipeDataParallelTopology

+from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder
+
import argparse
import pytest
import json

@ -152,8 +154,8 @@ def checkpoint_correctness_verification(args,
    compare_lr_scheduler_states(trained_model, loaded_model)


-@pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
-                    reason="lamb is not installed")
+@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
+                    reason="lamb is not compatible")
def test_checkpoint_unfused_optimizer(tmpdir):
    config_dict = {
        "train_batch_size": 2,

@ -264,11 +266,11 @@ def test_checkpoint_fused_optimizer(tmpdir):
                          'Adam'),
                         (2,
                          True,
-                          'deepspeed_adam'),
+                          'Adam'),
                     ])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
-    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
-        pytest.skip("cpu-adam is not installed")
+    if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
+        pytest.skip("cpu-adam is not compatible")

    config_dict = {
        "train_batch_size": 2,

@ -320,14 +322,14 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_opt
                          "Adam"),
                         (2,
                          True,
-                          'deepspeed_adam'),
+                          'Adam'),
                     ])
def test_checkpoint_zero_no_optimizer(tmpdir,
                                      zero_stage,
                                      use_cpu_offload,
                                      adam_optimizer):
-    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
-        pytest.skip("cpu-adam is not installed")
+    if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
+        pytest.skip("cpu-adam is not compatible")

    config_dict = {
        "train_batch_size": 2,

@ -385,11 +387,11 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
                          'Adam'),
                         (2,
                          True,
-                          'deepspeed_adam'),
+                          'Adam'),
                     ])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
-    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
-        pytest.skip("cpu-adam is not installed")
+    if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
+        pytest.skip("cpu-adam is not compatible")

    config_dict = {
        "train_batch_size": 2,

@ -459,11 +461,11 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optim
                          'Adam'),
                         (2,
                          True,
-                          'deepspeed_adam'),
+                          'Adam'),
                     ])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
-    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
-        pytest.skip("cpu-adam is not installed")
+    if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
+        pytest.skip("cpu-adam is not compatible")

    config_dict = {
        "train_batch_size": 2,
@ -1,16 +1,16 @@
import argparse
import torch
-import apex
import time
import numpy as np
import pytest
import copy

import deepspeed
-if not deepspeed.ops.__installed_ops__['cpu-adam']:
-    pytest.skip("cpu-adam is not installed", allow_module_level=True)
-else:
-    from deepspeed.ops.adam import DeepSpeedCPUAdam
+from deepspeed.ops.adam import FusedAdam
+from deepspeed.ops.op_builder import CPUAdamBuilder
+
+if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
+    pytest.skip("cpu-adam is not compatible", allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):

@ -32,6 +32,7 @@ def check_equal(first, second, atol=1e-2, verbose=False):
    (1048576),
]) # yapf: disable
def test_cpu_adam_opt(model_size):
+    from deepspeed.ops.adam import DeepSpeedCPUAdam
    device = 'cpu'
    rng_state = torch.get_rng_state()
    param = torch.nn.Parameter(torch.randn(model_size, device=device))

@ -42,7 +43,7 @@ def test_cpu_adam_opt(model_size):
    param2 = torch.nn.Parameter(param2_data)

    optimizer1 = torch.optim.AdamW([param1])
-    optimizer2 = apex.optimizers.FusedAdam([param2])
+    optimizer2 = FusedAdam([param2])
    optimizer = DeepSpeedCPUAdam([param])

    for i in range(10):
@ -16,8 +16,8 @@ import deepspeed

import sys

-if not deepspeed.ops.__installed_ops__['transformer']:
-    pytest.skip("transformer kernels are not installed", allow_module_level=True)
+#if not deepspeed.ops.__installed_ops__['transformer']:
+#    pytest.skip("transformer kernels are not installed", allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):

@ -254,6 +254,7 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
    check_equal(base_grads, ds_grads, atol=atol, verbose=verbose)


+#test_backward[3-1024-120-16-24-True-True-0.05]
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
                         [
                             (3,1024,120,16,24,True,False, 0.05),
@ -16,8 +16,8 @@ import deepspeed

import sys

-if not deepspeed.ops.__installed_ops__['transformer']:
-    pytest.skip("transformer kernels are not installed", allow_module_level=True)
+#if not deepspeed.ops.__installed_ops__['transformer']:
+#    pytest.skip("transformer kernels are not installed", allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):
@ -8,9 +8,6 @@ import numpy as np
from common import distributed_test
from simple_model import SimpleModel, args_from_dict

-lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
-                                    reason="lamb is not installed")
-

def run_model_step(model, gradient_list):
    for value in gradient_list:

@ -168,7 +165,6 @@ def test_fused_some_overflow(tmpdir):
    _test_fused_some_overflow(args)


-@lamb_available
def test_unfused_no_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,

@ -212,7 +208,6 @@ def test_unfused_no_overflow(tmpdir):
    _test_unfused_no_overflow(args)


-@lamb_available
def test_unfused_all_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,

@ -258,7 +253,6 @@ def test_unfused_all_overflow(tmpdir):
    _test_unfused_all_overflow(args)


-@lamb_available
def test_unfused_some_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
@ -1,18 +1,21 @@
|
||||
import torch
|
||||
import apex
|
||||
import deepspeed
|
||||
import argparse
|
||||
import pytest
|
||||
import json
|
||||
import os
|
||||
from deepspeed.ops.adam import FusedAdam
|
||||
from common import distributed_test
|
||||
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
|
||||
|
||||
lamb_available = pytest.mark.skipif(not deepspeed.ops.__installed_ops__['lamb'],
|
||||
reason="lamb is not installed")
|
||||
try:
|
||||
from apex import amp
|
||||
_amp_available = True
|
||||
except ImportError:
|
||||
_amp_available = False
|
||||
amp_available = pytest.mark.skip(_amp_available, reason="apex/amp is not installed")
|
||||
|
||||
|
||||
@lamb_available
|
||||
def test_lamb_fp32_grad_clip(tmpdir):
|
||||
config_dict = {
|
||||
"train_batch_size": 2,
|
||||
@ -48,7 +51,6 @@ def test_lamb_fp32_grad_clip(tmpdir):
    _test_lamb_fp32_grad_clip(args=args, model=model, hidden_dim=hidden_dim)


@lamb_available
def test_lamb_fp16_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
@ -86,7 +88,6 @@ def test_lamb_fp16_basic(tmpdir):
    _test_lamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)


@lamb_available
def test_lamb_fp16_empty_grad(tmpdir):
    config_dict = {
        "train_batch_size": 2,
@ -234,8 +235,8 @@ def test_adamw_fp16_empty_grad(tmpdir):
                             True),
                         ])
def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
    #    pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
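This hunk and the three that follow disable the same per-test guard. Unlike the module-level skips earlier, this check runs inside the test body so that only the `use_cpu_offload=True` parametrizations are skipped while the non-offload cases still run. A sketch of the pattern being retired, with an illustrative test name and parameter grid:

```python
import pytest
import deepspeed


@pytest.mark.parametrize('zero_stage, use_cpu_offload', [(1, False), (2, True)])
def test_zero_variant(zero_stage, use_cpu_offload):
    # Skip only the offload parametrizations when the CPU Adam op is absent;
    # parametrizations that never touch cpu-adam are unaffected.
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    ...  # body of the ZeRO test
```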
@ -302,8 +303,8 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offlo
                             True),
                         ])
def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
    #    pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
@ -402,8 +403,8 @@ def test_zero_static_scale_deprecated_format(tmpdir):
                             True),
                         ])
def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
    #    pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_batch_size": 4,
        "steps_per_print": 1,
@ -442,8 +443,8 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
                             True),
                         ])
def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
        pytest.skip("cpu-adam is not installed")
    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
    #    pytest.skip("cpu-adam is not installed")
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
@ -489,6 +490,7 @@ def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload):
    _test_zero_empty_partition(args)


@amp_available
def test_adam_amp_basic(tmpdir):
    config_dict = {"train_batch_size": 1, "steps_per_print": 1, "amp": {"enabled": True}}
    args = args_from_dict(tmpdir, config_dict)
@ -514,7 +516,7 @@ def test_adam_amp_basic(tmpdir):
    _test_adam_amp_basic(args=args, model=model, hidden_dim=hidden_dim)


@lamb_available
@amp_available
def test_lamb_amp_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
@ -552,6 +554,7 @@ def test_lamb_amp_basic(tmpdir):
    _test_lamb_amp_basic(args=args, model=model, hidden_dim=hidden_dim)


@amp_available
def test_adam_amp_o2(tmpdir):
    config_dict = {
        "train_batch_size": 2,
@ -590,6 +593,7 @@ def test_adam_amp_o2(tmpdir):
    _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)


@amp_available
def test_adam_amp_o2_empty_grad(tmpdir):
    config_dict = {
        "train_batch_size": 2,
@ -630,11 +634,11 @@ def test_adam_amp_o2_empty_grad(tmpdir):

@pytest.mark.parametrize('zero_stage, optimizer_constructor',
                         [(1,
                           apex.optimizers.FusedAdam),
                           FusedAdam),
                          (2,
                           torch.optim.Adam),
                          (2,
                           apex.optimizers.FusedAdam)])
                           FusedAdam)])
def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_constructor):
    config_dict = {
        "train_batch_size": 2,
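The swap above replaces `apex.optimizers.FusedAdam` with DeepSpeed's own `FusedAdam` in the supported-client-optimizer cases, part of removing the hard apex dependency. A sketch of the resulting parametrization, with an illustrative stand-in model (the real test builds a `SimpleModel` and hands the optimizer to `deepspeed.initialize`):

```python
import pytest
import torch
from deepspeed.ops.adam import FusedAdam


@pytest.mark.parametrize('zero_stage, optimizer_constructor',
                         [(1, FusedAdam),
                          (2, torch.optim.Adam),
                          (2, FusedAdam)])
def test_zero_supported_client_optimizer(zero_stage, optimizer_constructor):
    model = torch.nn.Linear(10, 10)  # stand-in model for illustration
    client_optimizer = optimizer_constructor(params=model.parameters())
    ...  # the suite passes client_optimizer through deepspeed.initialize()
```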
@ -6,9 +6,11 @@
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import SparseAttnBuilder

if not deepspeed.ops.__installed_ops__['sparse-attn']:
    pytest.skip("cpu-adam is not installed", allow_module_level=True)
if not deepspeed.ops.__compatible_ops__[SparseAttnBuilder.NAME]:
    pytest.skip("sparse attention op is not compatible on this system",
                allow_module_level=True)


def test_sparse_attention_module_availability():
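The sparse-attention guard shows the new compatibility registry in action: instead of asking whether an op was pre-built at install time, the test asks whether it *can* run on this system, keyed by the op builder's `NAME`. (The change also fixes the copy-pasted "cpu-adam" skip message in the old guard.) The gating pattern, taken directly from the diff:

```python
import pytest
import deepspeed
from deepspeed.ops.op_builder import SparseAttnBuilder

# __compatible_ops__ is keyed by each builder's NAME and reflects whether the
# op can be JIT-built and executed on the current hardware/software stack.
if not deepspeed.ops.__compatible_ops__[SparseAttnBuilder.NAME]:
    pytest.skip("sparse attention op is not compatible on this system",
                allow_module_level=True)
```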
@ -236,7 +238,7 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo


def _skip_on_cuda_compatability():
    pytest.skip("Skip these tests for now until we get our docker image fixed.")
    #pytest.skip("Skip these tests for now until we get our docker image fixed.")
    if torch.cuda.get_device_capability()[0] != 7:
        pytest.skip("needs compute capability 7; v100")
    cuda_major = int(torch.version.cuda.split('.')[0]) * 10
1
third_party/apex
vendored
Submodule third_party/apex deleted from 494f8ab3fc
1
version.txt
Normal file
@ -0,0 +1 @@
0.3.1