Sparse attn + ops/runtime refactor + v0.3.0 (#343)
* Sparse attn + ops/runtime refactor + v0.3.0

Co-authored-by: Arash Ashari <arashari@microsoft.com>
@@ -20,7 +20,7 @@ AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
2  .gitignore  vendored
@@ -8,7 +8,7 @@ deepspeed/git_version_info.py
# Build + installation data
build/
dist/
fused_lamb_*.so
*.so
deepspeed.egg-info/

# Website
@@ -8,7 +8,8 @@ RUN apt-get update && \
software-properties-common \
openssh-client openssh-server \
pdsh curl sudo net-tools \
vim iputils-ping wget
vim iputils-ping wget \
llvm-9-dev cmake

##############################################################################
# Installation Latest Git
@@ -85,7 +86,7 @@ RUN mkdir -p ${STAGE_DIR} && \
dpkg -i ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_TAG}_all.deb

##############################################################################
## Uncomment and set SSH Daemon port
## SSH daemon port inside container cannot conflict with host OS port
###############################################################################
ENV SSH_PORT=2222
RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \
@@ -1,49 +1,83 @@

jobs:
- job: Default
- job: DeepSpeed_Tests
timeoutInMinutes: 360
pool:
name: 'GPU_testing'
name: 'DS_testing'

strategy:
matrix:
Python36:
PyTorch12-CUDA100:
python.version: '3.6'
#Python35:
# python.version: '3.5'
#Python37:
cuda.version: '10.0'
pytorch.version: '1.2'
torchvision.version: '0.4.0'
runmodeltests: true
#PyTorch15-CUDA101:
# python.version: '3.7'
#Python38:
# python.version: '3.8'
# cuda.version: '10.1'
# pytorch.version: '1.5.0+cu101'
# torchvision.version: '0.6.0+cu101'
# runmodeltests: true
##PyTorch15-CUDA102:
# python.version: '3.7'
# cuda.version: '10.2'
# pytorch.version: '1.5'
# torchvision.version: '0.6.1'
# runmodeltests: true

variables:
conda_env: 'ds_test_py$(python.version)_cuda$(cuda.version)_pytorch$(pytorch.version)'

steps:
- task: UsePythonVersion@0
inputs:
versionSpec: '$(python.version)'
addToPath: true
architecture: 'x64'
displayName: 'Use Python $(python.version)'
# Unfortunately nvidia's nvcc_linux-64=<version> seems to install 10.1 regardless?
# Most of this complexity is a workaround to get the compiler toolchain to match the
# cudatoolkit runtime
- script: |
conda create --force --yes -n $(conda_env) python=$(python.version) cudatoolkit=$(cuda.version)
source activate $(conda_env)
conda install -q --yes conda
conda install -q --yes pip
conda install -q --yes gxx_linux-64
if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'

# Manually install torch/torchvision first to enforce versioning.
- script: |
source activate $(conda_env)
pip install --progress-bar=off torch==$(pytorch.version) torchvision==$(torchvision.version)
#-f https://download.pytorch.org/whl/torch_stable.html
./install.sh --local_only
#python -I basic_install_test.py
displayName: 'Install DeepSpeed'

- script: |
python -m pip install --upgrade pip
pip install --user -r requirements.txt
./install.sh --pip_sudo
displayName: 'Install dependencies'
source activate $(conda_env)
which python
python --version
which nvcc
nvcc --version
which deepspeed
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
python -c "import deepspeed; print('deepspeed:', deepspeed.__version__)"
displayName: 'Show environment'


- script: |
pre-commit run --all-files
displayName: 'Formatting checks'

- script: |
pytest --forked --verbose tests/unit/
source activate $(conda_env)
pytest --durations=0 --forked --verbose -x tests/unit/
displayName: 'Unit tests'

- script: |
source activate $(conda_env)
ln -s /data/Megatron-LM/data DeepSpeedExamples/Megatron-LM/
pip install --user -r DeepSpeedExamples/Megatron-LM/requirements.txt
pip install --progress-bar=off -r DeepSpeedExamples/Megatron-LM/requirements.txt
cd tests/model/
pytest -s run_sanity_check.py
rm -rf BingBertSquad/baseline
rm -rf Megatron_GPT2/baseline
pytest --durations=0 -s run_sanity_check.py
condition: and(succeeded(), eq(variables['runmodeltests'], true))
displayName: 'Model tests'

#BingBertSquad logs
@@ -52,35 +86,29 @@ jobs:
targetPath: '$(Build.SourcesDirectory)/tests/model/BingBertSquad/test/'
artifactName: BingBertSquad_logs
displayName: 'BingBertSquad log uploads'
condition: always()

# Megatron test logs
#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/test/'
# artifactName: Megatron_GPT2_logs
# displayName: 'Megatron GPT2 log uploads'
# condition: always()

#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/checkpoint_test_logs/'
# artifactName: Megatron_GPT2_checkpoint_logs
# displayName: 'Megatron GPT2 checkpoint log uploads'
# condition: always()
condition: eq(variables['runmodeltests'], true)


#BingBert logs
#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/pretrain_test/'
# artifactName: BingBert_pretrain_logs
# displayName: 'BingBert pretrain logs'
# condition: always()
- job: Code_Quality_Checks
pool:
name: 'DS_testing'
variables:
conda_env: 'ds_codetest'

#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/checkpoint_test_logs/'
# artifactName: BingBert_checkpoint_logs
# displayName: 'BingBert checkpoint logs'
# condition: always()
steps:
- script: |
conda create --force --yes -n $(conda_env) python=3.7
source activate $(conda_env)
displayName: 'Create code test environment'

- script: |
source activate $(conda_env)
pip install pre-commit
pre-commit run --all-files
displayName: 'Formatting checks'

- script: |
source activate $(conda_env)
pip install pylint
pylint --exit-zero deepspeed/
displayName: 'Code linter'

@@ -18,7 +18,7 @@ except Exception as err:
    raise err

try:
    fused_lamb = importlib.import_module('deepspeed_lamb_cuda')
    fused_lamb = importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
    print('deepspeed fused lamb kernels successfully installed')
except Exception as err:
    raise err
@@ -30,7 +30,8 @@ except ImportError:
    print("using new-style apex")

try:
    ds_transformer = importlib.import_module('deepspeed_transformer_cuda')
    ds_transformer = importlib.import_module(
        'deepspeed.ops.transformer.transformer_cuda')
    print('deepspeed transformer kernels successfully installed')
except Exception as err:
    raise err
2  bin/ds
@@ -1,6 +1,6 @@
#!/usr/bin/env python

from deepspeed.pt.deepspeed_run import main
from deepspeed.launcher.runner import main

if __name__ == '__main__':
    main()
120
csrc/sparse_attention/utils.cpp
Normal file
@ -0,0 +1,120 @@
|
||||
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
|
||||
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#ifdef _OPENMP
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
|
||||
|
||||
void segment_blocks(torch::Tensor layout,
|
||||
torch::Tensor idx,
|
||||
torch::Tensor scratch,
|
||||
int max_width,
|
||||
ret_t& ret)
|
||||
{
|
||||
size_t H = layout.size(0);
|
||||
size_t M = layout.size(1);
|
||||
size_t N = layout.size(2);
|
||||
torch::Tensor tmp = torch::zeros_like(layout);
|
||||
|
||||
auto _tmp = tmp.accessor<int, 3>();
|
||||
auto _layout = layout.accessor<int, 3>();
|
||||
auto _idx = idx.accessor<int, 3>();
|
||||
auto _scratch = scratch.accessor<int, 3>();
|
||||
std::vector<int> current(H, 0);
|
||||
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
// surrounding indices
|
||||
std::vector<int> ii_left(max_width, -1);
|
||||
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
|
||||
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
int v = _layout[h][m][n];
|
||||
if (v == 0) continue;
|
||||
int n_left = ii_left[max_width - 1];
|
||||
int m_top = ii_top[max_width - 1][n];
|
||||
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
|
||||
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
|
||||
int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
|
||||
int width = std::min(left, std::min(top, topleft)) + 1;
|
||||
|
||||
// reset width if blocks cannot be
|
||||
// packed together (i.e., there's a 1 "in the middle")
|
||||
for (int nn = n_left + 1; nn < n; nn++)
|
||||
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
|
||||
_tmp[h][m][n] = width;
|
||||
|
||||
// update n_left ring buffer
|
||||
for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
|
||||
ii_left[max_width - 1] = n;
|
||||
|
||||
// update ii_top ring buffer
|
||||
for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
|
||||
ii_top[max_width - 1][n] = m;
|
||||
|
||||
// block is too small -- skip
|
||||
if (width != max_width) continue;
|
||||
|
||||
// retained blocks are set to zeros
|
||||
for (size_t km = 0; km < max_width; km++)
|
||||
for (size_t kn = 0; kn < max_width; kn++) {
|
||||
int mm = ii_top[km][n];
|
||||
int nn = ii_left[kn];
|
||||
if (mm < 0 || nn < 0) continue;
|
||||
_layout[h][mm][nn] = 0;
|
||||
_tmp[h][mm][nn] = 0;
|
||||
_scratch[h][current[h]][0] = (int)h;
|
||||
_scratch[h][current[h]][1] = (int)mm;
|
||||
_scratch[h][current[h]][2] = (int)nn;
|
||||
_scratch[h][current[h]][3] = _idx[h][mm][nn];
|
||||
current[h]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<torch::Tensor> to_cat;
|
||||
for (size_t h = 0; h < H; h++)
|
||||
if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
|
||||
if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
|
||||
}
|
||||
|
||||
ret_t sdd_segment(torch::Tensor layout, int start_width)
|
||||
{
|
||||
ret_t ret;
|
||||
|
||||
// block index
|
||||
torch::Tensor idx = torch::zeros_like(layout);
|
||||
int current = 0;
|
||||
size_t H = layout.size(0);
|
||||
size_t M = layout.size(1);
|
||||
size_t N = layout.size(2);
|
||||
auto _layout = layout.accessor<int, 3>();
|
||||
auto _idx = idx.accessor<int, 3>();
|
||||
for (size_t h = 0; h < H; h++)
|
||||
for (size_t m = 0; m < M; m++)
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
if (_layout[h][m][n] == 0) continue;
|
||||
_idx[h][m][n] = current++;
|
||||
}
|
||||
|
||||
// scratch memory
|
||||
torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
|
||||
|
||||
for (int max_width = start_width; max_width > 0; max_width /= 2)
|
||||
segment_blocks(layout, idx, scratch, max_width, ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
|
||||
}
|
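For reference, a minimal pure-Python sketch (not part of this diff) of the block-index assignment step that sdd_segment above performs before packing blocks; only torch is assumed and the tiny layout tensor is a made-up example.

import torch

# Hypothetical 1-head layout over a 2x2 grid of blocks (1 = block kept).
layout = torch.tensor([[[1, 0],
                        [1, 1]]], dtype=torch.int32)

# Mirror of the C++ loop in sdd_segment: number the nonzero blocks in
# row-major order so later stages can look each retained block up.
idx = torch.zeros_like(layout)
current = 0
for h in range(layout.size(0)):
    for m in range(layout.size(1)):
        for n in range(layout.size(2)):
            if layout[h, m, n] == 0:
                continue
            idx[h, m, n] = current
            current += 1

print(idx.tolist())  # [[[0, 0], [1, 2]]] -- three retained blocks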
@ -1,34 +1,48 @@
|
||||
'''
|
||||
Copyright 2020 The Microsoft DeepSpeed Team
|
||||
'''
|
||||
import sys
|
||||
import types
|
||||
|
||||
from deepspeed.pt.deepspeed_light import DeepSpeedLight
|
||||
from deepspeed.pt.deepspeed_light import ADAM_OPTIMIZER, LAMB_OPTIMIZER
|
||||
from deepspeed.pt.deepspeed_lr_schedules import add_tuning_arguments
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.pt.deepspeed_cuda import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
|
||||
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
|
||||
|
||||
import deepspeed.pt.deepspeed_checkpointing as checkpointing
|
||||
from deepspeed.runtime.engine import DeepSpeedEngine
|
||||
from deepspeed.runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
|
||||
from deepspeed.runtime.lr_schedules import add_tuning_arguments
|
||||
from deepspeed.runtime.config import DeepSpeedConfig
|
||||
from deepspeed.runtime.activation_checkpointing import checkpointing
|
||||
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
|
||||
from deepspeed.utils import logger
|
||||
|
||||
try:
|
||||
from deepspeed.git_version_info import git_hash, git_branch
|
||||
from deepspeed.git_version_info import version, git_hash, git_branch
|
||||
except ImportError:
|
||||
version = "0.0.0+unknown"
|
||||
git_hash = None
|
||||
git_branch = None
|
||||
|
||||
# Export version information
|
||||
__version_major__ = 0
|
||||
__version_minor__ = 2
|
||||
__version_patch__ = 0
|
||||
version, __version_tag__ = version.split('+')
|
||||
__version_major__ = int(version.split('.')[0])
|
||||
__version_minor__ = int(version.split('.')[1])
|
||||
__version_patch__ = int(version.split('.')[2])
|
||||
__version__ = '.'.join(
|
||||
map(str,
|
||||
[__version_major__,
|
||||
__version_minor__,
|
||||
__version_patch__]))
|
||||
__version__ = f"{__version__}+{__version_tag__}"
|
||||
__git_hash__ = git_hash
|
||||
__git_branch__ = git_branch
|
||||
|
||||
# Provide backwards compatibility with old deepspeed.pt module structure, should hopefully not be used
|
||||
pt = types.ModuleType('pt', 'dummy pt module for backwards compatability')
|
||||
deepspeed = sys.modules[__name__]
|
||||
setattr(deepspeed, 'pt', pt)
|
||||
setattr(deepspeed.pt, 'deepspeed_utils', deepspeed.runtime.utils)
|
||||
sys.modules['deepspeed.pt'] = deepspeed.pt
|
||||
sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils
|
||||
setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config)
|
||||
sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config
|
||||
|
||||
|
||||
def initialize(args,
|
||||
model,
|
||||
@ -90,16 +104,16 @@ def initialize(args,
|
||||
__git_branch__),
|
||||
)
|
||||
|
||||
engine = DeepSpeedLight(args=args,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
model_parameters=model_parameters,
|
||||
training_data=training_data,
|
||||
lr_scheduler=lr_scheduler,
|
||||
mpu=mpu,
|
||||
dist_init_required=dist_init_required,
|
||||
collate_fn=collate_fn,
|
||||
config_params=config_params)
|
||||
engine = DeepSpeedEngine(args=args,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
model_parameters=model_parameters,
|
||||
training_data=training_data,
|
||||
lr_scheduler=lr_scheduler,
|
||||
mpu=mpu,
|
||||
dist_init_required=dist_init_required,
|
||||
collate_fn=collate_fn,
|
||||
config_params=config_params)
|
||||
|
||||
return_items = [
|
||||
engine,
|
||||
|
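A rough usage sketch of the refactored entry point (not part of this diff): initialize() constructs the DeepSpeedEngine shown above and returns it together with the optimizer, the training dataloader and the lr scheduler. The model, config path and argparse wiring below are hypothetical placeholders.

import argparse
import torch
import deepspeed

# Hypothetical args namespace; a real run would parse --deepspeed_config etc.
args = argparse.Namespace(deepspeed_config='ds_config.json', local_rank=0)
model = torch.nn.Linear(512, 512)

engine, optimizer, dataloader, lr_scheduler = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters())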
@@ -10,7 +10,7 @@ import base64
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER

from deepspeed.pt.log_utils import logger
from deepspeed.utils import logger


def parse_args():
@@ -14,8 +14,8 @@ from copy import deepcopy

import torch.cuda

from deepspeed.pt.deepspeed_constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.pt.log_utils import logger
from deepspeed.runtime.constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.utils import logger

DLTS_HOSTFILE = "/job/hostfile"
EXPORT_ENVS = ["NCCL", "PYTHON"]
@@ -285,7 +285,7 @@ def main(args=None):
        sys.executable,
        "-u",
        "-m",
        "deepspeed.pt.deepspeed_launch",
        "deepspeed.launcher.launch",
        "--world_info={}".format(world_info_base64),
        "--master_addr={}".format(args.master_addr),
        "--master_port={}".format(args.master_port)
@@ -328,7 +328,7 @@ def main(args=None):
        sys.executable,
        "-u",
        "-m",
        "deepspeed.pt.deepspeed_launch",
        "deepspeed.launcher.launch",
        '--world_info={}'.format(world_info_base64),
        "--node_rank=%n",
        "--master_addr={}".format(args.master_addr),
0  deepspeed/ops/__init__.py  Normal file

1  deepspeed/ops/lamb/__init__.py  Normal file
@@ -0,0 +1 @@
from deepspeed.ops.lamb.fused_lamb import FusedLamb
@ -3,46 +3,37 @@ Copyright 2019 The Microsoft DeepSpeed Team
|
||||
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
|
||||
|
||||
'''
|
||||
import types
|
||||
import torch
|
||||
import importlib
|
||||
import torch
|
||||
|
||||
|
||||
class FusedLamb(torch.optim.Optimizer):
|
||||
"""Implements LAMB algorithm. Currently GPU-only. Requires DeepSpeed adapted Apex to be installed via
|
||||
``python setup.py install --cuda_ext --cpp_ext``.
|
||||
"""Implements the LAMB algorithm. Currently GPU-only.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Tutorial
|
||||
|
||||
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
|
||||
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
|
||||
https://arxiv.org/abs/1904.00962
|
||||
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups.
|
||||
lr (float, optional): learning rate. (default: 1e-3)
|
||||
bias_correction (bool, optional): bias correction (default: True)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square. (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability. (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
|
||||
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
|
||||
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False) NOT SUPPORTED in FusedAdam!
|
||||
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
|
||||
adds eps to the bias-corrected second moment estimate before
|
||||
evaluating square root instead of adding it to the square root of
|
||||
second moment estimate as in the original paper. (default: False)
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
max_grad_norm (float, optional): value used to clip global grad norm
|
||||
(default: 0.0)
|
||||
max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
|
||||
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
|
||||
amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb!
|
||||
"""
|
||||
def __init__(self,
|
||||
params,
|
||||
@ -58,7 +49,14 @@ class FusedLamb(torch.optim.Optimizer):
|
||||
min_coeff=0.01,
|
||||
amsgrad=False):
|
||||
global fused_lamb_cuda
|
||||
fused_lamb_cuda = importlib.import_module("deepspeed_lamb_cuda")
|
||||
try:
|
||||
fused_lamb_cuda = importlib.import_module(
|
||||
"deepspeed.ops.lamb.fused_lamb_cuda")
|
||||
except ImportError as err:
|
||||
print(
|
||||
"Unable to import Lamb cuda extension, please build DeepSpeed with cuda/cpp extensions."
|
||||
)
|
||||
raise err
|
||||
|
||||
if amsgrad:
|
||||
raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
|
||||
@ -153,9 +151,7 @@ class FusedLamb(torch.optim.Optimizer):
|
||||
if grad is None:
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
|
||||
)
|
||||
raise RuntimeError('FusedLamb does not support sparse gradients')
|
||||
|
||||
state = self.state[p]
|
||||
|
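A hedged usage sketch for the relocated optimizer (not part of this diff): it requires DeepSpeed built with the lamb CUDA extension, and the toy model and hyperparameters are made up; the keyword arguments mirror the docstring above.

import torch
from deepspeed.ops.lamb import FusedLamb

model = torch.nn.Linear(1024, 1024).cuda()
# lr, betas, eps, weight_decay plus the lamb-coefficient clamps max_coeff/min_coeff.
optimizer = FusedLamb(model.parameters(),
                      lr=1e-3,
                      betas=(0.9, 0.999),
                      eps=1e-8,
                      weight_decay=0.01,
                      max_coeff=10.0,
                      min_coeff=0.01)

loss = model(torch.randn(8, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()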
6  deepspeed/ops/sparse_attention/__init__.py  Normal file
@@ -0,0 +1,6 @@
from .sparsity_config import SparsityConfig, DenseSparsityConfig, FixedSparsityConfig, VariableSparsityConfig, BigBirdSparsityConfig, BSLongformerSparsityConfig
from .softmax import Softmax
from .matmul import MatMul
from .sparse_self_attention import SparseSelfAttention
from .bert_sparse_self_attention import BertSparseSelfAttention
from .sparse_attention_utils import SparseAttentionUtils

78  deepspeed/ops/sparse_attention/bert_sparse_self_attention.py  Executable file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Copyright 2020 The Microsoft DeepSpeed Team
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig
|
||||
|
||||
|
||||
class BertSparseSelfAttention(nn.Module):
|
||||
"""Implements Sparse Self Attention layer of Bert model based on https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373
|
||||
|
||||
For more information please see, TODO DeepSpeed Sparse Transformer.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
# SparsityConfig parameters needs to be set accordingly
|
||||
sparsity_config=FixedSparsityConfig(num_heads=4)):
|
||||
"""Initialize the bert sparse self attention layer.
|
||||
|
||||
Note) you can use any of the provided sparsity configs or simply add yours!
|
||||
|
||||
Arguments:
|
||||
config: required: Bert model config
|
||||
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the FixedSparsityConfig class.
|
||||
"""
|
||||
|
||||
super(BertSparseSelfAttention, self).__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size,
|
||||
config.num_attention_heads))
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.sparse_self_attention = SparseSelfAttention(sparsity_config)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
|
||||
self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
"""Applies forward phase of bert sparse self attention
|
||||
|
||||
Arguments:
|
||||
hidden_states: required: hidden_states tensor of the bert model
|
||||
attn_mask: required: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
|
||||
|
||||
Return:
|
||||
context_layer: a dense tensor containing the attention context
|
||||
"""
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
context_layer = self.sparse_self_attention(query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
key_padding_mask=attention_mask)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
return context_layer
|
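A hedged sketch (not part of this diff) of how the layer above could be dropped into a BERT block; it needs a CUDA device and triton, the config namespace, block size and shapes are hypothetical, and the mask is simply forwarded to SparseSelfAttention as key_padding_mask as in the code above.

import types
import torch
from deepspeed.ops.sparse_attention import BertSparseSelfAttention, FixedSparsityConfig

# Hypothetical BERT-style config: hidden size must be divisible by the head count.
config = types.SimpleNamespace(hidden_size=1024, num_attention_heads=16)
layer = BertSparseSelfAttention(
    config,
    sparsity_config=FixedSparsityConfig(num_heads=16)).cuda().half()

batch, seq = 2, 512  # seq is assumed to be a multiple of the sparsity block size
hidden_states = torch.randn(batch, seq, 1024, device='cuda', dtype=torch.half)
attention_mask = torch.zeros(batch, seq, device='cuda', dtype=torch.half)

context = layer(hidden_states, attention_mask)  # (batch, seq, hidden_size)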
729
deepspeed/ops/sparse_attention/matmul.py
Normal file
@ -0,0 +1,729 @@
|
||||
# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
|
||||
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
|
||||
import importlib
|
||||
import warnings
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
warnings.warn("Unable to import triton, sparse attention will not be accessible")
|
||||
import torch
|
||||
import math
|
||||
from deepspeed.ops.sparse_attention.trsrc import matmul
|
||||
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
class _sparse_matmul(torch.autograd.Function):
|
||||
|
||||
sdd_cache = dict()
|
||||
dsd_cache = dict()
|
||||
dds_cache = dict()
|
||||
locks = dict()
|
||||
|
||||
# Given an array sizes representing reduction size for each
|
||||
# column of a block-mode matrix multiplication,
|
||||
# performs load-balancing to achieve smaller, more evenly sized reductions
|
||||
# between `seg_size` elements
|
||||
@staticmethod
|
||||
def load_balance(sizes, block):
|
||||
# segment size
|
||||
# heuristics taken from OpenAI blocksparse code
|
||||
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
|
||||
max_size = sizes.max()
|
||||
min_size = sizes[sizes != 0].min()
|
||||
#if max_size > min_size * 2.0:
|
||||
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
|
||||
#else:
|
||||
# seg_max = max_size
|
||||
seg_max = max_size
|
||||
seg_min = max(triton.cdiv(seg_max, 4), 4)
|
||||
# split reduction into segments
|
||||
div = sizes // seg_max
|
||||
rem = sizes % seg_max
|
||||
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
|
||||
width = packs.sum()
|
||||
segments = torch.empty(width, dtype=sizes.dtype)
|
||||
column = torch.empty_like(segments)
|
||||
lockid = torch.zeros_like(segments)
|
||||
maxid = torch.zeros_like(segments)
|
||||
nlocks = 0
|
||||
current = 0
|
||||
col_idx = 0
|
||||
for i in range(len(sizes)):
|
||||
d, r = div[i], rem[i]
|
||||
isempty = sizes[i] < seg_min
|
||||
last = current + d + (r >= seg_min) + isempty
|
||||
# column id
|
||||
column[current:last] = col_idx
|
||||
# lock id
|
||||
if d > 1 or (d == 1 and r >= seg_min):
|
||||
nlocks += 1
|
||||
lockid[current:last] = nlocks
|
||||
maxid[current:last] = last - current
|
||||
# segment size
|
||||
segments[current:current + d] = seg_max
|
||||
if r < seg_min and not isempty:
|
||||
segments[current + d - 1] += r
|
||||
if r >= seg_min or isempty:
|
||||
segments[current + d] = r
|
||||
current = last
|
||||
col_idx += 1
|
||||
offsets = torch.zeros_like(segments)
|
||||
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
|
||||
return segments, column, lockid, maxid, offsets
|
||||
|
||||
@staticmethod
|
||||
def get_locks(size, dev):
|
||||
if dev not in _sparse_matmul.locks or \
|
||||
size > _sparse_matmul.locks[dev].size(0):
|
||||
_sparse_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
|
||||
return _sparse_matmul.locks[dev]
|
||||
|
||||
##########################
|
||||
# SPARSE = DENSE x DENSE #
|
||||
##########################
|
||||
cpp_utils = importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils')
|
||||
sdd_segment = cpp_utils.sdd_segment
|
||||
|
||||
@staticmethod
|
||||
def make_sdd_lut(layout, block, dtype, device):
|
||||
start_width = 64 // block
|
||||
segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
|
||||
luts, widths, packs = [], [], []
|
||||
for size, nnz in segmented:
|
||||
width = nnz.shape[0] // (size * size)
|
||||
h = nnz[:, 0]
|
||||
i = nnz[:, 1]
|
||||
j = nnz[:, 2]
|
||||
b = nnz[:, 3]
|
||||
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
|
||||
luts.append(lut.type(torch.int32).to(device))
|
||||
widths.append(width)
|
||||
packs.append(size)
|
||||
# create locks
|
||||
return luts, None, widths, packs
|
||||
|
||||
@staticmethod
|
||||
def _sdd_matmul(a,
|
||||
b,
|
||||
trans_a,
|
||||
trans_b,
|
||||
trans_c,
|
||||
spdims,
|
||||
block,
|
||||
luts,
|
||||
num_locks,
|
||||
widths,
|
||||
packs,
|
||||
bench,
|
||||
time):
|
||||
if trans_c:
|
||||
a, b = b, a
|
||||
trans_a, trans_b = not trans_b, not trans_a
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
AS3 = a.size(2 if trans_a else 3)
|
||||
BS0 = b.size(0)
|
||||
BS1 = b.size(1)
|
||||
BS2 = b.size(3 if trans_b else 2)
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
is_16_multiple = AS3 % 16 == 0
|
||||
is_32_multiple = AS3 % 32 == 0
|
||||
is_64_multiple = AS3 % 64 == 0
|
||||
if not is_16_multiple:
|
||||
raise ValueError('Reduction size for SDD must be a multiple of 16')
|
||||
# create kernel
|
||||
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
|
||||
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=a.device)
|
||||
for lut, width, pack in zip(luts, widths, packs):
|
||||
num_lock = 1
|
||||
key = (block,
|
||||
a.dtype,
|
||||
b.dtype,
|
||||
trans_a,
|
||||
trans_b,
|
||||
trans_c,
|
||||
pack,
|
||||
is_32_multiple,
|
||||
is_64_multiple)
|
||||
if key not in _sparse_matmul.sdd_cache:
|
||||
F32TK = [8, 16]
|
||||
F16TK = [16]
|
||||
F16TK += [32] if is_32_multiple else []
|
||||
F16TK += [64] if is_64_multiple else []
|
||||
TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]
|
||||
defines = {
|
||||
'TM': block * pack,
|
||||
'TN': block * pack,
|
||||
'TMN': block * block * pack * pack,
|
||||
'BLOCK': block,
|
||||
'TK': TK,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': '1' if trans_a else 'lda',
|
||||
'STRIDE_AK': 'lda' if trans_a else '1',
|
||||
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||
'STRIDE_CM': 'ldc',
|
||||
'STRIDE_CN': '1',
|
||||
'SDD': True,
|
||||
'TZ': 1,
|
||||
'NAME': 'sdd_kernel'
|
||||
}
|
||||
_sparse_matmul.sdd_cache[key] = triton.kernel(matmul,
|
||||
defines=defines,
|
||||
num_warps=[1,
|
||||
2,
|
||||
4])
|
||||
#_sparse_matmul.sdd_cache[key] = triton.kernel(src, defines=defines, num_warps=[1, 2, 4])
|
||||
|
||||
kernel = _sparse_matmul.sdd_cache[key]
|
||||
# create output
|
||||
locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)
|
||||
# maximum grid size is 65535
|
||||
# so operation might be decomposed into multiple
|
||||
# kernel calls
|
||||
max_width = 49152
|
||||
total = 0 if bench else None
|
||||
for off_width in range(0, width, max_width):
|
||||
current = kernel(a,
|
||||
b,
|
||||
c,
|
||||
a.stride(2),
|
||||
b.stride(2),
|
||||
block,
|
||||
a.stride(0),
|
||||
b.stride(0),
|
||||
c.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(1),
|
||||
c.stride(0),
|
||||
AS2,
|
||||
AS2,
|
||||
AS3,
|
||||
off_width,
|
||||
lut,
|
||||
locks,
|
||||
num_lock,
|
||||
grid=lambda opt:
|
||||
[opt.d('TZ'),
|
||||
min(max_width,
|
||||
width - off_width),
|
||||
AS0],
|
||||
bench=bench)
|
||||
total = total + current if bench else None
|
||||
time[0] = total
|
||||
# save for backward pass
|
||||
return c
|
||||
|
||||
##########################
|
||||
# DENSE = DENSE x SPARSE #
|
||||
##########################
|
||||
|
||||
# Given a binary layout of 0s and 1s,
|
||||
# Construct look-up table for efficient execution on GPUs
|
||||
@staticmethod
|
||||
def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
|
||||
# load-balancing
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
segments = _empty.clone()
|
||||
column = _empty.clone()
|
||||
depth = _empty.clone()
|
||||
lockid = _empty.clone()
|
||||
maxid = _empty.clone()
|
||||
offsets = _empty.clone()
|
||||
current_offset = 0
|
||||
current_maxid = 0
|
||||
for z in range(layout.size(0)):
|
||||
if trans:
|
||||
sizes = torch.sum(layout[z, :, :], 1)
|
||||
else:
|
||||
sizes = torch.sum(layout[z, :, :], 0)
|
||||
z_segments, z_column, z_lockid, z_maxid, z_offsets = _sparse_matmul.load_balance(sizes, block)
|
||||
z_depth = z * torch.ones_like(z_segments)
|
||||
z_lockid[z_lockid > 0] += current_maxid
|
||||
current_maxid = z_lockid.max()
|
||||
# concatenate depth
|
||||
segments = torch.cat((segments, z_segments))
|
||||
column = torch.cat((column, z_column))
|
||||
depth = torch.cat((depth, z_depth))
|
||||
maxid = torch.cat((maxid, z_maxid))
|
||||
offsets = torch.cat((offsets, current_offset + z_offsets))
|
||||
lockid = torch.cat((lockid, z_lockid))
|
||||
current_offset += layout[z, :, :].sum()
|
||||
segments *= step
|
||||
# pointer increments
|
||||
if trans:
|
||||
nnz = layout.nonzero()
|
||||
else:
|
||||
nnz = layout.transpose(1, 2).nonzero()
|
||||
num_blocks = nnz.size(0)
|
||||
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
|
||||
idx = transform(nnz[:, 2] * block)
|
||||
xincs = idx.clone()
|
||||
xincs[1:] -= idx[:-1]
|
||||
# divide block into multiple steps
|
||||
div = block // step
|
||||
xincs = xincs.view(-1, 1).repeat(1, div)
|
||||
xincs[:, 1:] = step
|
||||
xincs[:, 0] -= (div - 1) * step
|
||||
# first increment for each reduction is actually the offset
|
||||
xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
|
||||
xincs = xincs.view(-1)
|
||||
# block-mode input increments
|
||||
if trans:
|
||||
widx = torch.arange(num_blocks)
|
||||
else:
|
||||
widx = _empty.clone()
|
||||
current_offset = 0
|
||||
for z in range(layout.size(0)):
|
||||
layoutw = layout[z, :, :].clone()
|
||||
msum = layoutw.sum()
|
||||
layoutw[layoutw > 0] = 1 + torch.arange(msum)
|
||||
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
|
||||
current_offset += msum
|
||||
widx = widx
|
||||
wincs = widx * block * block
|
||||
wincs[1:] -= widx[:-1] * block * block
|
||||
wincs = wincs.view(-1, 1).repeat(1, div)
|
||||
if trans:
|
||||
wincs[:, 1:] = step
|
||||
wincs[:, 0] -= (div - 1) * step
|
||||
else:
|
||||
wincs[:, 1:] = step * block
|
||||
wincs[:, 0] -= (div - 1) * step * block
|
||||
wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
|
||||
wincs = wincs.view(-1)
|
||||
# adjust offset and segment size
|
||||
offsets *= 2 * div
|
||||
segments *= div
|
||||
# create header
|
||||
width = column.size(0)
|
||||
offsets += 6 * width
|
||||
header = torch.stack((offsets,
|
||||
segments,
|
||||
column,
|
||||
depth,
|
||||
lockid,
|
||||
maxid),
|
||||
dim=1).view(-1).contiguous()
|
||||
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
|
||||
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
|
||||
# create lut
|
||||
lut = torch.cat((header, incs))
|
||||
lut = lut.type(torch.int32).to(device)
|
||||
# create locks
|
||||
num_locks = max(1, lockid.max())
|
||||
return lut, num_locks, width, None
|
||||
|
||||
@staticmethod
|
||||
def _dds_matmul(a,
|
||||
b,
|
||||
trans_a,
|
||||
trans_b,
|
||||
trans_c,
|
||||
spdims,
|
||||
block,
|
||||
lut,
|
||||
num_locks,
|
||||
width,
|
||||
packs,
|
||||
bench,
|
||||
time):
|
||||
# shapes / dtypes
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
AS3 = a.size(2 if trans_a else 3)
|
||||
BS0 = spdims[0]
|
||||
BS1 = block * spdims[2 if trans_b else 1]
|
||||
BS2 = block * spdims[1 if trans_b else 2]
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||
if key not in _sparse_matmul.dds_cache:
|
||||
TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
|
||||
TK = [8] if dtype == torch.float32 else [16]
|
||||
defines = {
|
||||
'TM': TM,
|
||||
'TN': block,
|
||||
'TK': TK,
|
||||
'BLOCK': block,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': 1 if trans_a else 'lda',
|
||||
'STRIDE_AK': 'lda' if trans_a else 1,
|
||||
'STRIDE_BN': block if trans_b else 1,
|
||||
'STRIDE_BK': 1 if trans_b else block,
|
||||
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||
'NAME': 'dds_kernel',
|
||||
'DDS': True
|
||||
}
|
||||
_sparse_matmul.dds_cache[key] = triton.kernel(matmul,
|
||||
defines=defines,
|
||||
num_warps=[4])
|
||||
#_sparse_matmul.dds_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
|
||||
kernel = _sparse_matmul.dds_cache[key]
|
||||
# output
|
||||
CS0 = AS0
|
||||
CS1 = AS1
|
||||
CS2 = BS2 if trans_c else AS2
|
||||
CS3 = AS2 if trans_c else BS2
|
||||
locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
time[0] = kernel(a,
|
||||
b,
|
||||
c,
|
||||
a.stride(2),
|
||||
block,
|
||||
c.stride(2),
|
||||
a.stride(0),
|
||||
b.stride(0),
|
||||
c.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(1),
|
||||
c.stride(1),
|
||||
AS2,
|
||||
BS2,
|
||||
0,
|
||||
0,
|
||||
lut,
|
||||
locks,
|
||||
num_locks,
|
||||
grid=lambda opt: [width,
|
||||
triton.cdiv(AS2,
|
||||
opt.d('TM')),
|
||||
AS0],
|
||||
bench=bench)
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def _dsd_matmul(a,
|
||||
b,
|
||||
trans_a,
|
||||
trans_b,
|
||||
trans_c,
|
||||
spdims,
|
||||
block,
|
||||
lut,
|
||||
num_locks,
|
||||
width,
|
||||
packs,
|
||||
bench,
|
||||
time):
|
||||
# shapes / dtypes
|
||||
AS0 = spdims[0]
|
||||
AS1 = block * spdims[2 if trans_a else 1]
|
||||
AS2 = block * spdims[1 if trans_a else 2]
|
||||
BS0 = b.size(0)
|
||||
BS1 = b.size(1)
|
||||
BS2 = b.size(3 if trans_b else 2)
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||
if key not in _sparse_matmul.dsd_cache:
|
||||
TN = [64, 128] if dtype == torch.float32 else [64, 128, 256]
|
||||
TK = [8] if dtype == torch.float32 else [16]
|
||||
defines = {
|
||||
'TM': block,
|
||||
'TN': TN,
|
||||
'TK': TK,
|
||||
'BLOCK': block,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': 1 if trans_a else block,
|
||||
'STRIDE_AK': block if trans_a else 1,
|
||||
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||
'NAME': 'dsd_kernel',
|
||||
'DSD': True
|
||||
}
|
||||
_sparse_matmul.dsd_cache[key] = triton.kernel(matmul,
|
||||
defines=defines,
|
||||
num_warps=[4])
|
||||
#_sparse_matmul.dsd_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
|
||||
kernel = _sparse_matmul.dsd_cache[key]
|
||||
# output
|
||||
CS0 = BS0
|
||||
CS1 = BS1
|
||||
CS2 = BS3 if trans_c else AS1
|
||||
CS3 = AS1 if trans_c else BS3
|
||||
locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
time[0] = kernel(a,
|
||||
b,
|
||||
c,
|
||||
block,
|
||||
b.stride(2),
|
||||
c.stride(2),
|
||||
a.stride(0),
|
||||
b.stride(0),
|
||||
c.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(1),
|
||||
c.stride(1),
|
||||
BS3,
|
||||
AS1,
|
||||
0,
|
||||
0,
|
||||
lut,
|
||||
locks,
|
||||
num_locks,
|
||||
grid=lambda opt: [width,
|
||||
triton.cdiv(BS3,
|
||||
opt.d('TN')),
|
||||
BS0],
|
||||
bench=bench)
|
||||
return c
|
||||
|
||||
fn = {
|
||||
'sdd': _sdd_matmul.__get__(object),
|
||||
'dsd': _dsd_matmul.__get__(object),
|
||||
'dds': _dds_matmul.__get__(object)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
a,
|
||||
b,
|
||||
trans_a,
|
||||
trans_b,
|
||||
trans_c,
|
||||
mode,
|
||||
spdims,
|
||||
block,
|
||||
c_lut,
|
||||
c_num_locks,
|
||||
c_width,
|
||||
c_packs,
|
||||
c_bench,
|
||||
c_time,
|
||||
da_lut,
|
||||
da_num_locks,
|
||||
da_width,
|
||||
da_packs,
|
||||
da_bench,
|
||||
da_time,
|
||||
db_lut,
|
||||
db_num_locks,
|
||||
db_width,
|
||||
db_packs,
|
||||
db_bench,
|
||||
db_time):
|
||||
c = _sparse_matmul.fn[mode](a,
|
||||
b,
|
||||
trans_a,
|
||||
trans_b,
|
||||
trans_c,
|
||||
spdims,
|
||||
block,
|
||||
c_lut,
|
||||
c_num_locks,
|
||||
c_width,
|
||||
c_packs,
|
||||
c_bench,
|
||||
c_time)
|
||||
# save for backward
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.da_num_locks = da_num_locks
|
||||
ctx.da_lut = da_lut
|
||||
ctx.da_width = da_width
|
||||
ctx.da_packs = da_packs
|
||||
ctx.da_bench = da_bench
|
||||
ctx.da_time = da_time
|
||||
ctx.db_lut = db_lut
|
||||
ctx.db_num_locks = db_num_locks
|
||||
ctx.db_width = db_width
|
||||
ctx.db_bench = db_bench
|
||||
ctx.db_packs = db_packs
|
||||
ctx.db_time = db_time
|
||||
ctx.mode = mode
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.trans_a = trans_a
|
||||
ctx.trans_b = trans_b
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dc):
|
||||
# saved for backward
|
||||
a, b = ctx.saved_tensors
|
||||
mode = ctx.mode
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _sparse_matmul.fn[mode_da](dc,
|
||||
b,
|
||||
False,
|
||||
not ctx.trans_b,
|
||||
ctx.trans_a,
|
||||
ctx.spdims,
|
||||
ctx.block,
|
||||
ctx.da_lut,
|
||||
ctx.da_num_locks,
|
||||
ctx.da_width,
|
||||
ctx.da_packs,
|
||||
ctx.da_bench,
|
||||
ctx.da_time)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _sparse_matmul.fn[mode_db](a,
|
||||
dc,
|
||||
not ctx.trans_a,
|
||||
False,
|
||||
ctx.trans_b,
|
||||
ctx.spdims,
|
||||
ctx.block,
|
||||
ctx.db_lut,
|
||||
ctx.db_num_locks,
|
||||
ctx.db_width,
|
||||
ctx.db_packs,
|
||||
ctx.db_bench,
|
||||
ctx.db_time)
|
||||
return da, db, None, None, None,\
|
||||
None, None, None, None,\
|
||||
None, None, None, None, None, None,\
|
||||
None, None, None, None, None, None,\
|
||||
None, None, None, None, None, None
|
||||
|
||||
|
||||
class MatMul:
|
||||
"""Block-Sparse MatMul class; this class handles three types of matrix-multiplication:
|
||||
- sparse = dense X dense
|
||||
- dense = sparse X dense
|
||||
- dense = dense X sparse
|
||||
|
||||
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
|
||||
"""
|
||||
def make_lut(self, dtype, device):
|
||||
"""Generates the sparsity layout/s used in block-sparse matmul
|
||||
"""
|
||||
key = (dtype, device)
|
||||
if key in self.lut_cache:
|
||||
return self.lut_cache[key]
|
||||
# C look-up table
|
||||
layout, block = self.layout, self.block
|
||||
step = 8 if dtype == torch.float32 else 16
|
||||
if self.mode == 'sdd':
|
||||
c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
|
||||
elif self.mode == 'dsd':
|
||||
c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
|
||||
elif self.mode == 'dds':
|
||||
c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
|
||||
# DA look-up table
|
||||
if self.mode == 'sdd':
|
||||
da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, True, device)
|
||||
elif self.mode == 'dsd':
|
||||
da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
|
||||
elif self.mode == 'dds':
|
||||
da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
|
||||
# DB look-up table
|
||||
if self.mode == 'sdd':
|
||||
db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, False, device)
|
||||
elif self.mode == 'dsd':
|
||||
db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
|
||||
elif self.mode == 'dds':
|
||||
db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
|
||||
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
|
||||
da_lut, da_num_locks, da_width, da_packs,\
|
||||
db_lut, db_num_locks, db_width, db_packs)
|
||||
return self.lut_cache[key]
|
||||
|
||||
def __init__(self, layout, block, mode, trans_a=False, trans_b=False, bench=False):
|
||||
"""Initialize the Block-Sparse MatMul class.
|
||||
|
||||
Arguments:
|
||||
layout: required: sparsity layout tensor
|
||||
block: required: an integer determining the block size.
|
||||
mode: required: a string determining type of matmul; ('sdd') sparse = dense X dense, ('dsd') dense = sparse X dense, ('dds') dense = dense X sparse
|
||||
trans_a: optional: a boolean determining if multiplication needs to be applied on transpose of input a; default is false
|
||||
trans_b: optional: a boolean determining if multiplication needs to be applied on transpose of input b; default is false
|
||||
bench: optional: set if you want to do benchmarking
|
||||
"""
|
||||
|
||||
if mode not in ['sdd', 'dsd', 'dds']:
|
||||
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
|
||||
# look-up table cache
|
||||
self.lut_cache = dict()
|
||||
# attributes
|
||||
self.trans_a = trans_a
|
||||
self.trans_b = trans_b
|
||||
self.mode = mode
|
||||
self.spdims = layout.shape
|
||||
self.block = block
|
||||
self.layout = layout
|
||||
# timings
|
||||
self.bench = bench
|
||||
self.time_c = None
|
||||
self.time_da = None
|
||||
self.time_db = None
|
||||
|
||||
# pad shapes of a tensor to make it
|
||||
# compatible with kernel calls
|
||||
@staticmethod
|
||||
def _pad_shape(x, is_sparse):
|
||||
max_dim = 3 if is_sparse else 4
|
||||
for i in range(max_dim - x.dim()):
|
||||
x = x.unsqueeze(0)
|
||||
return x
|
||||
|
||||
def __call__(self, a, b):
|
||||
"""Applies Block-Sparse MatMul.
|
||||
|
||||
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
|
||||
|
||||
Arguments:
|
||||
a: required: a dense/block-sparse tensor; first input of mat-mul
|
||||
b: required: a dense/block-sparse tensor; second input of mat-mul
|
||||
|
||||
Return:
|
||||
c: a dense/block-sparse tensor result of a X b
|
||||
"""
|
||||
|
||||
|
||||
c_lut, c_num_locks, c_width, c_packs,\
|
||||
da_lut, da_num_locks, da_width, da_packs,\
|
||||
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
|
||||
# timings
|
||||
time_c = [None]
|
||||
time_da = [None]
|
||||
time_db = [None]
|
||||
# pad shapes with ones
|
||||
a = MatMul._pad_shape(a, self.mode == 'dsd')
|
||||
b = MatMul._pad_shape(b, self.mode == 'dds')
|
||||
# execute
|
||||
c = _sparse_matmul.apply(a,
|
||||
b,
|
||||
self.trans_a,
|
||||
self.trans_b,
|
||||
False,
|
||||
self.mode,
|
||||
self.spdims,
|
||||
self.block,
|
||||
c_lut,
|
||||
c_num_locks,
|
||||
c_width,
|
||||
c_packs,
|
||||
self.bench,
|
||||
time_c,
|
||||
da_lut,
|
||||
da_num_locks,
|
||||
da_width,
|
||||
da_packs,
|
||||
self.bench,
|
||||
time_da,
|
||||
db_lut,
|
||||
db_num_locks,
|
||||
db_width,
|
||||
db_packs,
|
||||
self.bench,
|
||||
time_db)
|
||||
self.time_c = time_c[0]
|
||||
self.time_da = time_da[0]
|
||||
self.time_db = time_db[0]
|
||||
return c
|
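A hedged sketch (not part of this diff) of the 'sdd' mode described in the MatMul docstring above, computing block-sparse attention scores from dense q/k; it assumes a CUDA device, triton installed, and toy shapes whose reduction dimension is a multiple of 16.

import torch
from deepspeed.ops.sparse_attention import MatMul

block = 16
heads, seq, head_dim = 4, 256, 64
# Block-level sparsity layout: a 1 keeps the corresponding (block x block) tile.
layout = torch.tril(torch.ones(heads, seq // block, seq // block, dtype=torch.int64))

# sparse = dense x dense: multiply q by k^T, materializing only the kept tiles.
sdd = MatMul(layout, block, 'sdd', trans_a=False, trans_b=True)

q = torch.randn(1, heads, seq, head_dim, device='cuda', dtype=torch.half)
k = torch.randn(1, heads, seq, head_dim, device='cuda', dtype=torch.half)
scores = sdd(q, k)  # roughly (batch, nnz_blocks, block, block)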
292
deepspeed/ops/sparse_attention/softmax.py
Normal file
@ -0,0 +1,292 @@
|
||||
# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
|
||||
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
|
||||
|
||||
import warnings
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
warnings.warn("Unable to import triton, sparse attention will not be accessible")
|
||||
import torch
|
||||
import math
|
||||
from deepspeed.ops.sparse_attention.trsrc import softmax_fwd, softmax_bwd
|
||||
|
||||
fwd_kernels = dict()
|
||||
bwd_kernels = dict()
|
||||
|
||||
|
||||
class _sparse_softmax(torch.autograd.Function):
|
||||
|
||||
bwd_kernels = dict()
|
||||
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
sizes = _empty.clone()
|
||||
# sizes along rows
|
||||
for h in range(layout.shape[0]):
|
||||
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
|
||||
# offsets in block format
|
||||
offsets = torch.zeros_like(sizes)
|
||||
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
|
||||
# block indices
|
||||
idx = torch.arange(layout.sum())
|
||||
head = layout.nonzero()[:, 0]
|
||||
rows = layout.nonzero()[:, 1]
|
||||
columns = layout.nonzero()[:, 2]
|
||||
core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
|
||||
# construct look-up table
|
||||
offsets = offsets * 4 + 2 * sizes.numel()
|
||||
header = torch.stack((sizes, offsets), dim=1).view(-1)
|
||||
lut = torch.cat((header, core)).type(torch.int32).to(device)
|
||||
return lut, int(sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def make_kernel(cache,
|
||||
src,
|
||||
max_k,
|
||||
dtype,
|
||||
block,
|
||||
apply_scale,
|
||||
apply_rpe,
|
||||
apply_kp_mask,
|
||||
apply_attn_mask,
|
||||
kp_mask_mode,
|
||||
attn_mask_mode):
|
||||
if max_k >= 32768:
|
||||
raise NotImplementedError('Reductions larger than 32768 elements '\
|
||||
'are not yet implemented')
|
||||
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
|
||||
pad = num_warps * 32 * 2
|
||||
TN = (int(max_k) + pad - 1) // pad * pad
|
||||
# just-in-time compile kernel
|
||||
key = (block,
|
||||
dtype,
|
||||
num_warps,
|
||||
TN,
|
||||
apply_scale,
|
||||
apply_rpe,
|
||||
apply_kp_mask,
|
||||
apply_attn_mask,
|
||||
kp_mask_mode,
|
||||
attn_mask_mode)
|
||||
if key not in cache:
|
||||
defines = {
|
||||
'TM': [1],
|
||||
'TN': [TN],
|
||||
'TYPE': dtype,
|
||||
'BLOCK': block,
|
||||
'INFINITY': {
|
||||
torch.float32: 'F32_INFINITY',
|
||||
torch.float16: 'F16_INFINITY'
|
||||
}[dtype]
|
||||
}
|
||||
if apply_scale:
|
||||
defines['APPLY_SCALE'] = True
|
||||
if apply_rpe:
|
||||
defines['APPLY_RPE'] = True
|
||||
if apply_kp_mask:
|
||||
defines['APPLY_KP_MASK'] = True
|
||||
if kp_mask_mode == 'mul':
|
||||
defines['KP_MASK_MUL'] = True
|
||||
if apply_attn_mask:
|
||||
defines['APPLY_ATTN_MASK'] = True
|
||||
if attn_mask_mode == 'mul':
|
||||
defines['ATTN_MASK_MUL'] = True
|
||||
kernel = triton.kernel(src, defines=defines, num_warps=[num_warps])
|
||||
cache[key] = kernel
|
||||
return cache[key]
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
x,
|
||||
scale,
|
||||
rpe,
|
||||
key_padding_mask,
|
||||
attn_mask,
|
||||
kp_mask_mode,
|
||||
attn_mask_mode,
|
||||
spdims,
|
||||
block,
|
||||
lut,
|
||||
num_blocks,
|
||||
maxlut,
|
||||
bench,
|
||||
time):
|
||||
apply_scale = False if scale == 1.0 else True
|
||||
|
||||
# handle None rpe
|
||||
if rpe is None:
|
||||
apply_rpe = False
|
||||
stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
|
||||
rpe = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_rpe = True
|
||||
stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
|
||||
|
||||
# handle None key_padding_mask
|
||||
if key_padding_mask is None:
|
||||
apply_kp_mask = False
|
||||
stride_zkpm = 0
|
||||
key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_kp_mask = True
|
||||
stride_zkpm = key_padding_mask.stride(0)
|
||||
|
||||
# handle None attention_mask
|
||||
if attn_mask is None:
|
||||
apply_attn_mask = False
|
||||
stride_zattnm = 0
|
||||
attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_attn_mask = True
|
||||
stride_zattnm = attn_mask.stride(0)
|
||||
|
||||
# run kernel
|
||||
kernel = _sparse_softmax.make_kernel(fwd_kernels,
|
||||
softmax_fwd,
|
||||
maxlut * block,
|
||||
x.dtype,
|
||||
block,
|
||||
apply_scale,
|
||||
apply_rpe,
|
||||
apply_kp_mask,
|
||||
apply_attn_mask,
|
||||
kp_mask_mode,
|
||||
attn_mask_mode)
|
||||
M = x.shape[0]
|
||||
grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.d('TM')), M]
|
||||
|
||||
# run kernel
|
||||
time[0] = kernel(x, scale, lut, rpe, key_padding_mask, attn_mask,\
|
||||
num_blocks, maxlut,\
|
||||
x.stride(0),\
|
||||
stride_zrpe, stride_hrpe, stride_srpe,\
|
||||
stride_zkpm, stride_zattnm,\
|
||||
grid=grid, bench=bench)
|
||||
# save to context
|
||||
ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(x, lut)
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.maxlut = maxlut
|
||||
ctx.scale = scale
|
||||
ctx.apply_scale = apply_scale
|
||||
ctx.apply_rpe = apply_rpe
|
||||
ctx.apply_kp_mask = apply_kp_mask
|
||||
ctx.apply_attn_mask = apply_attn_mask
|
||||
ctx.kp_mask_mode = kp_mask_mode
|
||||
ctx.attn_mask_mode = attn_mask_mode
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dx):
|
||||
# retrieve from context
|
||||
x, lut = ctx.saved_tensors
|
||||
# run kernel
|
||||
kernel = _sparse_softmax.make_kernel(bwd_kernels,
|
||||
softmax_bwd,
|
||||
ctx.maxlut * ctx.block,
|
||||
x.dtype,
|
||||
ctx.block,
|
||||
ctx.apply_scale,
|
||||
ctx.apply_rpe,
|
||||
ctx.apply_kp_mask,
|
||||
ctx.apply_attn_mask,
|
||||
ctx.kp_mask_mode,
|
||||
ctx.attn_mask_mode)
|
||||
M = x.shape[0]
|
||||
grid = lambda opt: [
|
||||
triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block,
|
||||
opt.d('TM')),
|
||||
M
|
||||
]
|
||||
kernel(x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
|
||||
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Softmax:
|
||||
"""Block-Sparse Softmax class; this class computes softmax on a block sparse matrix. It is also able to apply either/all of the following masks:
|
||||
- relative position embedding
|
||||
- key padding mask
|
||||
- attention mask
|
||||
|
||||
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
|
||||
"""
|
||||
|
||||
sparse_softmax = _sparse_softmax.apply
|
||||
|
||||
def make_lut(self, device):
|
||||
"""Generates the sparsity layout used in block-sparse softmax
|
||||
"""
|
||||
key = (device, )
|
||||
if key not in self.lut_cache:
|
||||
self.lut_cache[key] = _sparse_softmax.make_lut(self.layout,
|
||||
self.block,
|
||||
device)
|
||||
return self.lut_cache[key]
|
||||
|
||||
def __init__(self, layout, block, bench=False):
|
||||
"""Initialize the Block-Sparse Softmax class.
|
||||
|
||||
Arguments:
|
||||
layout: required: sparsity layout tensor
|
||||
block: required: an integer determining the block size.
|
||||
bench: optional: set if you want to do benchmarking
|
||||
"""
|
||||
|
||||
self.num_blocks = layout.sum()
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
self.block = block
|
||||
self.bench = bench
|
||||
self.lut_cache = dict()
|
||||
|
||||
def __call__(self,
|
||||
x,
|
||||
scale=1.,
|
||||
rpe=None,
|
||||
key_padding_mask=None,
|
||||
attn_mask=None,
|
||||
key_padding_mask_mode='add',
|
||||
attn_mask_mode='add'):
|
||||
"""Applies softmax on a Block-Sparse input tensor.
|
||||
|
||||
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
|
||||
|
||||
Arguments:
|
||||
x: required: a block-sparse tensor that softmax is applied on it; computation will be in place and result will be returned in the same tensor
|
||||
scale: optional: a float value; x values will be multiplied by this value before normalization. Default value is 1.0.
|
||||
rpe: optional: a tensor same dimension as x that is used as relative position embedding
|
||||
key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
|
||||
attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
|
||||
key_padding_mask_mode: optional: a string determining if key_padding_mask needs to be added, `add`, or be multiplied, `mul`
|
||||
attn_mask_mode: optional: a string determining if attn_mask needs to be added, `add`, or be multiplied, `mul`
|
||||
|
||||
Return:
|
||||
x: a block-sparse tensor containing the input x normalized with softmax, with the given masks applied
|
||||
"""
|
||||
|
||||
time_y = [None]
|
||||
if rpe is not None and rpe.dtype != x.dtype:
|
||||
raise ValueError('relative position embedding must be %s' % x.dtype)
|
||||
if attn_mask is not None and attn_mask.dtype != x.dtype:
|
||||
raise ValueError('Attention mask must be %s' % x.dtype)
|
||||
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
|
||||
raise ValueError('Key padding mask must be %s' % x.dtype)
|
||||
lut, maxlut = self.make_lut(x.device)
|
||||
x = Softmax.sparse_softmax(x,
|
||||
scale,
|
||||
rpe,
|
||||
key_padding_mask,
|
||||
attn_mask,
|
||||
key_padding_mask_mode,
|
||||
attn_mask_mode,
|
||||
self.spdims,
|
||||
self.block,
|
||||
lut,
|
||||
self.num_blocks,
|
||||
maxlut,
|
||||
self.bench,
|
||||
time_y)
|
||||
self.time_y = time_y[0]
|
||||
return x
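A minimal usage sketch follows (shapes, dtype, and the CUDA device are illustrative assumptions; a GPU with the Triton kernels available is required). The block-sparse input is typically produced by the block-sparse MatMul in `sdd` mode and then normalized in place by this class:
# Hedged usage sketch; shapes, dtype and device are illustrative assumptions.
import torch
from deepspeed.ops.sparse_attention import MatMul, Softmax
block, num_heads, seq_len, head_dim = 16, 4, 256, 64
num_blocks = seq_len // block
# all-ones layout (i.e. dense) just to keep the example simple
layout = torch.ones((num_heads, num_blocks, num_blocks), dtype=torch.int64)
sparse_dot_sdd_nt = MatMul(layout, block, 'sdd', trans_a=False, trans_b=True)
sparse_softmax = Softmax(layout, block)
q = torch.randn(1, num_heads, seq_len, head_dim, dtype=torch.half, device='cuda')
k = torch.randn_like(q)
scores = sparse_dot_sdd_nt(q, k)                        # block-sparse QK^T scores
probs = sparse_softmax(scores, scale=head_dim ** -0.5)  # in-place block-sparse softmax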
|
225
deepspeed/ops/sparse_attention/sparse_attention_utils.py
Normal file
@ -0,0 +1,225 @@
|
||||
"""
|
||||
Copyright 2020 The Microsoft DeepSpeed Team
|
||||
"""
|
||||
|
||||
import torch
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from deepspeed.ops.sparse_attention import BertSparseSelfAttention, SparsityConfig
|
||||
'''
|
||||
This file contains a few utility functions for adapting pretrained models to use the sparse self-attention module.
|
||||
'''
|
||||
|
||||
|
||||
class SparseAttentionUtils:
|
||||
"""This class provides some utility functions that are use integrating sparse attention into transformer models.
|
||||
Such utilities include extending position embeddings, replacing the current self-attention layer with sparse attention, padding sequences to a multiple of the block size, etc.
|
||||
|
||||
"""
|
||||
@staticmethod
|
||||
def extend_position_embedding(model, max_position):
|
||||
"""This function extends the position embedding weights of a model loaded from a checkpoint.
|
||||
It assumes the new max position is bigger than the original max length.
|
||||
|
||||
Arguments:
|
||||
model: required: a transformer model
|
||||
max_position: required: an integer determining new position embedding size
|
||||
Return:
|
||||
model: updated model; in which position embedding weights have been extended based on new size
|
||||
"""
|
||||
|
||||
if hasattr(model, 'bert'):
|
||||
original_max_position = model.bert.embeddings.position_embeddings.weight.size(
|
||||
0)
|
||||
assert max_position > original_max_position
|
||||
extend_multiples = max(1, max_position // original_max_position)
|
||||
model.bert.embeddings.position_embeddings.weight.data = model.bert.embeddings.position_embeddings.weight.repeat(
|
||||
extend_multiples,
|
||||
1)
|
||||
elif hasattr(model, 'roberta'):
|
||||
# RoBERTa has positions 0 & 1 reserved, so embedding size is max position + 2
|
||||
original_max_position, embed_size = model.roberta.embeddings.position_embeddings.weight.shape
|
||||
original_max_position -= 2
|
||||
extend_multiples = max(1, max_position // original_max_position)
|
||||
assert max_position > original_max_position
|
||||
max_position += 2
|
||||
extended_position_embedding = model.roberta.embeddings.position_embeddings.weight.new_empty(
|
||||
max_position,
|
||||
embed_size)
|
||||
k = 2
|
||||
for i in range(extend_multiples):
|
||||
extended_position_embedding[k:(
|
||||
k + original_max_position
|
||||
)] = model.roberta.embeddings.position_embeddings.weight[2:]
|
||||
k += original_max_position
|
||||
model.roberta.embeddings.position_embeddings.weight.data = extended_position_embedding
|
||||
else:
|
||||
raise ValueError(
|
||||
'Please extend \"extend_position_embedding\" function to support your model type. It currently only supports \"bert\" & \"roberta\"!'
|
||||
)
|
||||
|
||||
model.config.max_position_embeddings = max_position
|
||||
print(
|
||||
f'Extended position embeddings to {original_max_position * extend_multiples}'
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_tokenizer_model_max_length(tokenizer, max_position):
|
||||
"""This function updates the position embedding length of a tokenizer to a new max position.
|
||||
|
||||
Arguments:
|
||||
tokenizer: required: a transformer tokenizer
|
||||
max_position: required: an integer determining new position embedding size
|
||||
Return:
|
||||
tokenizer: updated tokenizer; in which model maximum length has been extended based on new size
|
||||
"""
|
||||
|
||||
tokenizer.model_max_length = max_position
|
||||
tokenizer.init_kwargs['model_max_length'] = max_position
|
||||
print(f'updated tokenizer model maximum length to {max_position}')
|
||||
|
||||
return tokenizer
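A hedged usage sketch of the two helpers above, assuming a Hugging Face BERT checkpoint and tokenizer (model and tokenizer names are illustrative):
# Hedged sketch; extends a 512-position BERT checkpoint to 1024 positions.
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.ops.sparse_attention.sparse_attention_utils import SparseAttentionUtils
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = SparseAttentionUtils.extend_position_embedding(model, max_position=1024)
tokenizer = SparseAttentionUtils.update_tokenizer_model_max_length(tokenizer, max_position=1024)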
|
||||
|
||||
@staticmethod
|
||||
def replace_model_self_attention_with_sparse_self_attention(
|
||||
model,
|
||||
max_position,
|
||||
# SparsityConfig parameters need to be set accordingly
|
||||
sparsity_config=SparsityConfig(num_heads=4)):
|
||||
"""This function replaces the self attention layers in model encoder with sparse self attention.
|
||||
It currently supports bert and roberta models and can easily be extended to other models by following similar steps here.
|
||||
For sparsityConfig, refer to the config class.
|
||||
|
||||
Arguments:
|
||||
model: required: a transformer model
|
||||
max_position: required: an integer determining new position embedding size
|
||||
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class
|
||||
|
||||
Return:
|
||||
model: updated model; in which the self attention layer has been replaced with the DeepSpeed Sparse Self Attention layer.
|
||||
"""
|
||||
|
||||
if hasattr(model, 'bert'):
|
||||
model.config.max_position_embeddings = max_position
|
||||
SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
|
||||
model.config,
|
||||
model.bert.encoder.layer,
|
||||
sparsity_config)
|
||||
elif hasattr(model, 'roberta'):
|
||||
model.config.max_position_embeddings = max_position + 2
|
||||
SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
|
||||
model.config,
|
||||
model.roberta.encoder.layer,
|
||||
sparsity_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Please extend \"update_model_self_attention_to_sparse_self_attention\" function to support \
|
||||
your model type. It currently only supports \"bert\" & \"roberta\"!'
|
||||
)
|
||||
return model
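A hedged usage sketch, continuing the BERT example above (the checkpoint name and sparsity parameters are illustrative assumptions):
# Hedged sketch; swaps BERT's dense self-attention for DeepSpeed sparse self-attention.
from transformers import BertForSequenceClassification
from deepspeed.ops.sparse_attention.sparsity_config import FixedSparsityConfig
from deepspeed.ops.sparse_attention.sparse_attention_utils import SparseAttentionUtils
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model = SparseAttentionUtils.replace_model_self_attention_with_sparse_self_attention(
    model,
    max_position=512,
    sparsity_config=FixedSparsityConfig(num_heads=12, block=16, num_local_blocks=4))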
|
||||
|
||||
@staticmethod
|
||||
def replace_self_attention_layer_with_sparse_self_attention_layer(
|
||||
config,
|
||||
layers,
|
||||
# SparsityConfig parameters need to be set accordingly
|
||||
sparsity_config=SparsityConfig(num_heads=4)):
|
||||
"""This function replaces the self attention layers in attention layer with sparse self attention.
|
||||
For sparsityConfig, refer to the config class.
|
||||
|
||||
Arguments:
|
||||
config: required: transformer model config
|
||||
layers: required: transformer model attention layers
|
||||
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class
|
||||
|
||||
Return:
|
||||
layers: updated attention layers; in which self attention layers have been replaced with the DeepSpeed Sparse Self Attention layer.
|
||||
"""
|
||||
|
||||
for layer in layers:
|
||||
deepspeed_sparse_self_attn = BertSparseSelfAttention(config, sparsity_config)
|
||||
deepspeed_sparse_self_attn.query = layer.attention.self.query
|
||||
deepspeed_sparse_self_attn.key = layer.attention.self.key
|
||||
deepspeed_sparse_self_attn.value = layer.attention.self.value
|
||||
|
||||
layer.attention.self = deepspeed_sparse_self_attn
|
||||
|
||||
return layers
|
||||
|
||||
@staticmethod
|
||||
def pad_to_block_size(block_size,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
inputs_embeds,
|
||||
pad_token_id,
|
||||
model_embeddings):
|
||||
"""This function pads input tokens and attention mask on sequence length dimension to be multiple of block size.
|
||||
This is a requirement for Sparse Transformer in which the self attention layer works on sequences of length multiple of block size.
|
||||
It needs to be called in your model, such as BertModel, right before you calculate the embedding outputs.
|
||||
Note)
|
||||
1- instead of passing your embedding layer to this function, you can simply add this function to your model; it can be simplified further if the given attention_mask and/or token_type_ids are None.
|
||||
2- you need to call the unpad function before returning your model output to unpad the encoder sequence output.
|
||||
|
||||
Arguments:
|
||||
block_size: required: an integer determining the block size of sparsity config.
|
||||
pad_token_id: required: an integer determining the pad token from the model config; such as bert.config.pad_token_id.
|
||||
input_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary
|
||||
attention_mask: a torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
|
||||
token_type_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
||||
position_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the indices of positions of each input sequence tokens in the position embeddings.
|
||||
inputs_embeds: an optional torch.FloatTensor of shape [batch_size, sequence_length, hidden_size] that contains embedded representation and can be passed instead of input_ids directly.
|
||||
model_embeddings: an optional object. If inputs_embeds are not none, this will be your model embeddings such as BertEmbeddings from your model such as BertModel. You can move this function inside your model and use self.embeddings instead of passing this parameter.
|
||||
|
||||
Return:
|
||||
pad_len: an integer determining how much the inputs have been padded to make the sequence length dimension a multiple of the block size.
|
||||
input_ids: if input_ids are not none padded input_ids otherwise none.
|
||||
attention_mask: if attention_mask is not none padded attention_mask otherwise none.
|
||||
token_type_ids: if token_type_ids are not none padded token_type_ids otherwise none.
|
||||
position_ids: if position_ids are not none padded position_ids otherwise none.
|
||||
inputs_embeds: if inputs_embeds are not none padded inputs_embeds otherwise none.
|
||||
"""
|
||||
|
||||
batch_size, seq_len = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
|
||||
|
||||
pad_len = (block_size - seq_len % block_size) % block_size
|
||||
if pad_len > 0:
|
||||
if inputs_embeds is not None:
|
||||
pad_input_ids = inputs_embeds.new_full((batch_size,
|
||||
pad_len),
|
||||
pad_token_id,
|
||||
dtype=torch.long)
|
||||
pad_inputs_embeds = model_embeddings(pad_input_ids)
|
||||
inputs_embeds = torch.cat([inputs_embeds, pad_inputs_embeds], dim=-2)
|
||||
# may not be needed as input_ids are not used if inputs_embeds are given
|
||||
if input_ids is not None:
|
||||
input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
|
||||
if position_ids is not None:
|
||||
# pad position_id with pad_token_id
|
||||
position_ids = F.pad(position_ids, (0, pad_len), value=pad_token_id)
|
||||
# pad attention mask without attention on the padding tokens
|
||||
attention_mask = F.pad(attention_mask, (0, pad_len), value=False)
|
||||
# pad token_type_ids with token_type_id = 0
|
||||
token_type_ids = F.pad(token_type_ids, (0, pad_len), value=0)
|
||||
|
||||
return pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
|
||||
|
||||
@staticmethod
|
||||
def unpad_sequence_output(pad_len, sequence_output):
|
||||
"""This function unpads sequence output if inputs of the model were padded.
|
||||
This is a requirement for Sparse Transformer in which the self attention layer works on sequences of length multiple of block size.
|
||||
It needs to be called in your model, such as BertModel, right before you return the model outputs.
|
||||
|
||||
Arguments:
|
||||
pad_len: required: an integer determining how much the model inputs have been padded to make the sequence length dimension a multiple of the block size.
|
||||
sequence_output: required: sequence output of the encoder layer.
|
||||
|
||||
Return:
|
||||
sequence_output: unpadded sequence output of the encoder layer.
|
||||
"""
|
||||
|
||||
if (pad_len > 0):
|
||||
sequence_output = sequence_output[:, :-pad_len]
|
||||
return sequence_output
|
142
deepspeed/ops/sparse_attention/sparse_self_attention.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
Copyright 2020 The Microsoft DeepSpeed Team
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import *
|
||||
import torch
|
||||
from collections import namedtuple
|
||||
from deepspeed.ops.sparse_attention import MatMul, Softmax, SparsityConfig
|
||||
import sys
|
||||
|
||||
|
||||
class SparseSelfAttention(nn.Module):
|
||||
"""Implements an efficient Sparse Self Attention of Transformer layer based on `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
|
||||
|
||||
For more information please see, TODO DeepSpeed Sparse Transformer.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
# SparsityConfig parameters need to be set accordingly
|
||||
sparsity_config=SparsityConfig(num_heads=4),
|
||||
key_padding_mask_mode='add',
|
||||
attn_mask_mode='mul'):
|
||||
"""Initialize the sparse self attention layer.
|
||||
Arguments:
|
||||
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class.
|
||||
key_padding_mask_mode: optional: a string determining if key padding mask needs to be added, `add`, or be multiplied, `mul`.
|
||||
attn_mask_mode: optional: a string determining if attention mask needs to be added, `add`, or be multiplied, `mul`.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# sparsity information
|
||||
self.sparsity_config = sparsity_config
|
||||
|
||||
# mask modes
|
||||
self.key_padding_mask_mode = key_padding_mask_mode
|
||||
self.attn_mask_mode = attn_mask_mode
|
||||
|
||||
ops = dict()
|
||||
|
||||
# add to cache
|
||||
def get_ops(self, H, L):
|
||||
|
||||
if L not in SparseSelfAttention.ops:
|
||||
sparsity_layout = self.sparsity_config.make_layout(L)
|
||||
sparse_dot_sdd_nt = MatMul(sparsity_layout,
|
||||
self.sparsity_config.block,
|
||||
'sdd',
|
||||
trans_a=False,
|
||||
trans_b=True)
|
||||
|
||||
sparse_dot_dsd_nn = MatMul(sparsity_layout,
|
||||
self.sparsity_config.block,
|
||||
'dsd',
|
||||
trans_a=False,
|
||||
trans_b=False)
|
||||
|
||||
sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)
|
||||
|
||||
SparseSelfAttention.ops[L] = (sparse_dot_sdd_nt,
|
||||
sparse_dot_dsd_nn,
|
||||
sparse_softmax)
|
||||
return SparseSelfAttention.ops[L]
|
||||
|
||||
def transpose_key_for_scores(self, x, L):
|
||||
bsz, num_heads, seq_len, head_dim = x.size()
|
||||
if seq_len != L:
|
||||
return x.permute(0, 1, 3, 2)
|
||||
return x
|
||||
|
||||
def transpose_mask_for_sparse(self, qtype, x, is_key_padding_mask=False):
|
||||
x = x.type(qtype)
|
||||
if is_key_padding_mask:
|
||||
xdim = x.dim()
|
||||
for d in range(xdim - 1, 0, -1):
|
||||
x = x.squeeze(dim=d)
|
||||
return x
|
||||
return x.squeeze()
|
||||
|
||||
# forward pass
|
||||
def forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
rpe=None,
|
||||
key_padding_mask=None,
|
||||
attn_mask=None):
|
||||
"""Applies forward phase of sparse self attention
|
||||
|
||||
Arguments:
|
||||
query: required: query tensor
|
||||
key: required: key tensor
|
||||
value: required: value tensor
|
||||
rpe: optional: a tensor same dimension as x that is used as relative position embedding
|
||||
key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
|
||||
attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
|
||||
key_padding_mask_mode: optional: a string determining if key_padding_mask needs to be added, `add`, or be multiplied, `mul`; configured in __init__
|
||||
attn_mask_mode: optional: a string determining if attn_mask needs to be added, `add`, or be multiplied, `mul`; configured in __init__
|
||||
|
||||
Return:
|
||||
attn_output: a dense tensor containing the attention context
|
||||
"""
|
||||
bsz, num_heads, tgt_len, head_dim = query.size()
|
||||
|
||||
# transpose back key if it is already transposed
|
||||
key = self.transpose_key_for_scores(key, tgt_len)
|
||||
|
||||
# check that operation is supported
|
||||
if query.shape != key.shape or key.shape != value.shape:
|
||||
raise NotImplementedError('only self-attention is supported for now')
|
||||
|
||||
# squeeze key_padding_mask if it is given
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = self.transpose_mask_for_sparse(query.dtype,
|
||||
key_padding_mask,
|
||||
is_key_padding_mask=True)
|
||||
|
||||
# squeeze attn_mask if it is given
|
||||
if attn_mask is not None:
|
||||
attn_mask = self.transpose_mask_for_sparse(query.dtype, attn_mask)
|
||||
|
||||
# cache look-up table computations etc
|
||||
sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops(num_heads, tgt_len)
|
||||
|
||||
scaling = float(head_dim)**-0.5
|
||||
|
||||
# attention scores
|
||||
attn_output_weights = sparse_dot_sdd_nt(query, key)
|
||||
attn_output_weights = sparse_softmax(
|
||||
attn_output_weights,
|
||||
scale=scaling,
|
||||
rpe=rpe,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
key_padding_mask_mode=self.key_padding_mask_mode,
|
||||
attn_mask_mode=self.attn_mask_mode)
|
||||
|
||||
# outputs
|
||||
attn_output = sparse_dot_dsd_nn(attn_output_weights, value)
|
||||
return attn_output
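A hedged usage sketch of the layer (shapes, dtype, and the CUDA device are illustrative assumptions; the sequence length must be a multiple of the block size):
# Hedged sketch of SparseSelfAttention forward.
import torch
from deepspeed.ops.sparse_attention.sparse_self_attention import SparseSelfAttention
from deepspeed.ops.sparse_attention.sparsity_config import FixedSparsityConfig
num_heads, seq_len, head_dim, block = 4, 256, 64, 16
attn = SparseSelfAttention(sparsity_config=FixedSparsityConfig(num_heads=num_heads, block=block))
q = torch.randn(2, num_heads, seq_len, head_dim, dtype=torch.half, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)
context = attn(q, k, v)   # dense tensor of shape (2, num_heads, seq_len, head_dim)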
|
663
deepspeed/ops/sparse_attention/sparsity_config.py
Normal file
@ -0,0 +1,663 @@
|
||||
"""
|
||||
Copyright 2020 The Microsoft DeepSpeed Team
|
||||
"""
|
||||
|
||||
import torch
|
||||
import random
|
||||
|
||||
|
||||
class SparsityConfig:
|
||||
"""Abstract Configuration class to store `sparsity configuration of a self attention layer`.
|
||||
It contains the shared properties of different block-sparse sparsity patterns. Each subclass extends it with the properties and functionality it requires.
|
||||
"""
|
||||
def __init__(self, num_heads, block=16, different_layout_per_head=False):
|
||||
"""Initialize the Sparsity Pattern Config.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
|
||||
|
||||
Arguments:
|
||||
num_heads: required: an integer determining number of attention heads of the layer.
|
||||
block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.
|
||||
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
|
||||
"""
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.block = block
|
||||
self.different_layout_per_head = different_layout_per_head
|
||||
self.num_layout_heads = num_heads if different_layout_per_head else 1
|
||||
|
||||
def setup_layout(self, seq_len):
|
||||
"""Create layout tensor for the given sequence length
|
||||
|
||||
Arguments:
|
||||
seq_len: required: an integer determining the sequence length; must be divisible by the block size.
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout of all head; initialized with zero
|
||||
"""
|
||||
|
||||
if (seq_len % self.block != 0):
|
||||
raise ValueError(
|
||||
f'Sequence Length, {seq_len}, needs to be divisible by Block size {self.block}!'
|
||||
)
|
||||
num_blocks = seq_len // self.block
|
||||
# TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
|
||||
layout = torch.zeros((self.num_heads, num_blocks, num_blocks), dtype=torch.int64)
|
||||
return layout
|
||||
|
||||
def check_and_propagate_first_head_layout(self, layout):
|
||||
"""If all heads require same sparsity layout, it propagate first head layout to all heads
|
||||
|
||||
Arguments:
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head
|
||||
"""
|
||||
|
||||
if not self.different_layout_per_head:
|
||||
layout[1:self.num_heads, :, :] = layout[0, :, :]
|
||||
return layout
|
||||
|
||||
|
||||
class DenseSparsityConfig(SparsityConfig):
|
||||
"""Configuration class to store `Dense` configuration.
|
||||
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
|
||||
"""
|
||||
def __init__(self, num_heads, block=16, different_layout_per_head=False):
|
||||
"""Initialize the Dense Sparsity Pattern Config.
|
||||
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
|
||||
|
||||
Arguments:
|
||||
num_heads: required: an integer determining number of attention heads of the layer.
|
||||
block: optional: an integer determining the block size; included for consistency with the other sparsity configs.
|
||||
different_layout_per_head: optional: this is just for the sake of consistency with other sparsity formats; can ignore it for DenseSparsityConfig
|
||||
"""
|
||||
|
||||
super().__init__(num_heads, block, different_layout_per_head)
|
||||
|
||||
def make_layout(self, seq_len):
|
||||
"""Set 1 to all blocks of the layout meanins the pattern is dense; not sparse.
|
||||
|
||||
Arguments:
|
||||
seq_len: required: an integer determining the underlying sequence length; must be <= max sequence length
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; for dense everything is 1
|
||||
"""
|
||||
|
||||
layout = self.setup_layout(seq_len)
|
||||
layout[:, :, :] = 1
|
||||
return layout
|
||||
|
||||
|
||||
class FixedSparsityConfig(SparsityConfig):
|
||||
"""Configuration class to store `Fixed` sparsity configuration.
|
||||
For more details about this sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
|
||||
This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_heads,
|
||||
block=16,
|
||||
different_layout_per_head=False,
|
||||
num_local_blocks=4,
|
||||
num_global_blocks=1,
|
||||
attention='bidirectional',
|
||||
horizontal_global_attention=False,
|
||||
num_different_global_patterns=1):
|
||||
"""Initialize `Fixed` Sparsity Pattern Config.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
|
||||
|
||||
Arguments:
|
||||
num_heads: required: an integer determining number of attention heads of the layer.
|
||||
block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.
|
||||
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
|
||||
num_local_blocks: optional: an integer determining the number of blocks in local attention window.
|
||||
num_global_blocks: optional: an integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention.
|
||||
attention: optional: a string determining the attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens that appear before them in the context; in this case, the upper triangular part of the attention matrix is empty. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix is the mirror of the lower triangular part.
|
||||
horizontal_global_attention: optional: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks.
|
||||
num_different_global_patterns: optional: an integer determining the number of different global attention layouts. While global attention can be fixed by which block/s are representative of any local window, since there are multiple heads, each head can use a different global representative. For example, with a 4-block local window and a global attention size of 1 block, we can have 4 different versions in which the first, second, third, or fourth block of each local window can be the global representative of that window. This parameter determines how many such patterns we want. Of course, there is a limitation based on num_local_blocks and num_global_blocks.
|
||||
"""
|
||||
|
||||
super().__init__(num_heads, block, different_layout_per_head)
|
||||
|
||||
self.num_local_blocks = num_local_blocks
|
||||
|
||||
if (num_local_blocks % num_global_blocks != 0):
|
||||
raise ValueError(
|
||||
f'Number of blocks in a local window, {num_local_blocks}, must be divisible by the number of global blocks, {num_global_blocks}!'
|
||||
)
|
||||
self.num_global_blocks = num_global_blocks
|
||||
|
||||
if (attention != 'unidirectional' and attention != 'bidirectional'):
|
||||
raise NotImplementedError(
|
||||
'only \"uni/bi-directional\" attentions are supported for now!')
|
||||
self.attention = attention
|
||||
|
||||
if (attention != 'bidirectional' and horizontal_global_attention):
|
||||
raise ValueError(
|
||||
'only \"bi-directional\" attentions can support horizontal global attention!'
|
||||
)
|
||||
self.horizontal_global_attention = horizontal_global_attention
|
||||
|
||||
if (num_different_global_patterns > 1 and not different_layout_per_head):
|
||||
raise ValueError(
|
||||
f'Number of different layouts cannot be more than one when you have set a single layout for all heads! Set different_layout_per_head to True.'
|
||||
)
|
||||
if (num_different_global_patterns > (num_local_blocks // num_global_blocks)):
|
||||
raise ValueError(
|
||||
f'Number of layout versions (num_different_global_patterns), {num_different_global_patterns}, cannot be larger than number of local window blocks divided by number of global blocks, {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!'
|
||||
)
|
||||
self.num_different_global_patterns = num_different_global_patterns
|
||||
|
||||
def set_local_layout(self, h, layout):
|
||||
"""Sets local attantion layout used by the given head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
for i in range(0, num_blocks, self.num_local_blocks):
|
||||
end = min(i + self.num_local_blocks, num_blocks)
|
||||
for row in range(i, end):
|
||||
for col in range(
|
||||
i,
|
||||
(row + 1 if self.attention == 'unidirectional' else end)):
|
||||
layout[h, row, col] = 1
|
||||
return layout
|
||||
|
||||
def set_global_layout(self, h, layout):
|
||||
"""Sets global attantion layout used by the given head in the sparse attention.
|
||||
|
||||
Currently we set global blocks starting from the last block of a local window to the first one. That means if a local window consists of 4 blocks and the global attention size is one block, we use block #4 in each local window as global. If we have a different layout per head, then other heads will get #3, #2, and #1. And if we have more heads (with different layouts set) than the number of global attention patterns, multiple heads may share the same global attention blocks.
|
||||
Note) if horizontal_global_attention is set, global blocks will be set both horizontally and vertically.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
first_global_block_idx = self.num_local_blocks - (
|
||||
1 + h % self.num_different_global_patterns) * self.num_global_blocks
|
||||
|
||||
# set global blocks for all full local windows; the possibly short last window is handled below
|
||||
end = num_blocks - (num_blocks % self.num_local_blocks)
|
||||
for i in range(first_global_block_idx, end, self.num_local_blocks):
|
||||
|
||||
# vertical global attention
|
||||
first_row = 0 if self.attention == 'bidirectional' else i
|
||||
#(((i // self.num_local_blocks) + 1) * self.num_local_blocks)
|
||||
#if (first_row < num_blocks):
|
||||
layout[h, first_row:, i:i + self.num_global_blocks] = 1
|
||||
|
||||
# horizontal global attention; only in bidirectional attention
|
||||
if (self.horizontal_global_attention):
|
||||
layout[h, i:i + self.num_global_blocks, :] = 1
|
||||
|
||||
# set last global blocks; handle possible short last local window
|
||||
if (end < num_blocks):
|
||||
start = min(end + first_global_block_idx,
|
||||
num_blocks - self.num_global_blocks)
|
||||
end = start + self.num_global_blocks
|
||||
|
||||
# vertical global attention
|
||||
first_row = 0 if self.attention == 'bidirectional' else start
|
||||
#(((start // self.num_local_blocks) + 1) * self.num_local_blocks)
|
||||
#if (first_row < num_blocks):
|
||||
layout[h, first_row:, start:end] = 1
|
||||
|
||||
# horizontal global attention
|
||||
if (self.horizontal_global_attention):
|
||||
layout[h, start:end, :] = 1
|
||||
return layout
|
||||
|
||||
def make_layout(self, seq_len):
|
||||
"""Generates `Fixed` sparsity layout used by each head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
seq_len: required: an integer determining the sequence length; must be divisible by the block size.
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed` sparsity layout of all head
|
||||
"""
|
||||
|
||||
layout = self.setup_layout(seq_len)
|
||||
for h in range(0, self.num_layout_heads):
|
||||
layout = self.set_local_layout(h, layout)
|
||||
layout = self.set_global_layout(h, layout)
|
||||
|
||||
layout = self.check_and_propagate_first_head_layout(layout)
|
||||
return layout
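A hedged sketch of building a `Fixed` layout (parameter values are illustrative assumptions):
# Hedged sketch; inspect the block-level layout produced by FixedSparsityConfig.
from deepspeed.ops.sparse_attention.sparsity_config import FixedSparsityConfig
config = FixedSparsityConfig(num_heads=8,
                             block=16,
                             num_local_blocks=4,
                             num_global_blocks=1,
                             attention='bidirectional')
layout = config.make_layout(seq_len=1024)    # tensor of shape (8, 64, 64) with entries in {0, 1}
density = layout[0].float().mean().item()    # fraction of blocks actually computed for head 0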
|
||||
|
||||
|
||||
class VariableSparsityConfig(SparsityConfig):
|
||||
"""Configuration class to store `Variable` sparsity configuration.
|
||||
This layout is an extension of FixedSparsityConfig in which:
|
||||
- user can set a random layout; the default value is zero, meaning no random blocks
|
||||
- user can provide a list of local block sizes
|
||||
- user can provide a list of global block indices.
|
||||
|
||||
For more details about `Fixed` sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
|
||||
This class extends parent class of `SparsityConfig` and customizes it for `Variable` sparsity.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_heads,
|
||||
block=16,
|
||||
different_layout_per_head=False,
|
||||
num_random_blocks=0,
|
||||
local_window_blocks=[4],
|
||||
global_block_indices=[0],
|
||||
global_block_end_indices=None,
|
||||
attention='bidirectional',
|
||||
horizontal_global_attention=False):
|
||||
"""Initialize `Variable` Sparsity Pattern Config.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
|
||||
|
||||
Arguments:
|
||||
num_heads: required: an integer determining number of attention heads of the layer.
|
||||
block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.
|
||||
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. Currently this sparsity config can only assign single layout to all heads; needs to be extended for different layout per head.
|
||||
num_random_blocks: optional: an integer determining the number of random blocks in each block row.
|
||||
local_window_blocks: optional: a list of integers determining the number of blocks in each local attention window. It assumes first number determines # of blocks in the first local window, second the second window, ..., and the last number determines the number of blocks in the remaining local windows.
|
||||
global_block_indices: optional: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Default value is only index 0. Notice that if global_block_end_indices parameter is set, this parameter is used as starting index of each global window.
|
||||
global_block_end_indices: optional: a list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global_block_indices parameter, and combining this two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention.
|
||||
|
||||
attention: optional: a string determining the attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens that appear before them in the context; in this case, the upper triangular part of the attention matrix is empty. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix is the mirror of the lower triangular part.
|
||||
horizontal_global_attention: optional: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks.
|
||||
"""
|
||||
|
||||
super().__init__(num_heads, block, different_layout_per_head)
|
||||
|
||||
self.num_random_blocks = num_random_blocks
|
||||
self.local_window_blocks = local_window_blocks
|
||||
self.global_block_indices = global_block_indices
|
||||
|
||||
if (global_block_end_indices is not None):
|
||||
if (len(global_block_indices) != len(global_block_end_indices)):
|
||||
raise ValueError(
|
||||
f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!'
|
||||
)
|
||||
for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)):
|
||||
if start_idx >= end_idx:
|
||||
raise ValueError(
|
||||
f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!'
|
||||
)
|
||||
self.global_block_end_indices = global_block_end_indices
|
||||
|
||||
if (attention != 'unidirectional' and attention != 'bidirectional'):
|
||||
raise NotImplementedError(
|
||||
'only \"uni/bi-directional\" attentions are supported for now!')
|
||||
self.attention = attention
|
||||
|
||||
if (attention != 'bidirectional' and horizontal_global_attention):
|
||||
raise ValueError(
|
||||
'only \"bi-directional\" attentions can support horizontal global attention!'
|
||||
)
|
||||
self.horizontal_global_attention = horizontal_global_attention
|
||||
|
||||
def set_random_layout(self, h, layout):
|
||||
"""Sets random attantion layout used by the given head in the sparse attention.
|
||||
Note) By default, it assumes there will be a unique random block layout for all heads; unless the `different_layout_per_head` parameter is set, in which case each head can have a different random layout.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
if (num_blocks < self.num_random_blocks):
|
||||
raise ValueError(
|
||||
f'Number of random blocks, {self.num_random_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
|
||||
)
|
||||
for row in range(0, num_blocks):
|
||||
rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
|
||||
layout[h, row, rnd_cols] = 1
|
||||
return layout
|
||||
|
||||
def set_local_layout(self, h, layout):
|
||||
"""Sets local attantion layout used by the given head in the sparse attention.
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
start_block_idx = 0
|
||||
end_block_idx = 0
|
||||
for block_size in self.local_window_blocks:
|
||||
end_block_idx += block_size
|
||||
end_block_idx = min(end_block_idx, num_blocks)
|
||||
for row in range(start_block_idx, end_block_idx):
|
||||
for col in range(
|
||||
start_block_idx,
|
||||
(row + 1 if self.attention == 'unidirectional' else end_block_idx)):
|
||||
layout[h, row, col] = 1
|
||||
start_block_idx += block_size
|
||||
|
||||
# if there is any remaining unattended part, use the last local window block size as the local window for the remaining applicable local windows
|
||||
for i in range(start_block_idx, num_blocks, block_size):
|
||||
end_block_idx = min(i + block_size, num_blocks)
|
||||
for row in range(i, end_block_idx):
|
||||
for col in range(
|
||||
i,
|
||||
(row + 1 if self.attention == 'unidirectional' else end_block_idx)):
|
||||
layout[h, row, col] = 1
|
||||
return layout
|
||||
|
||||
def set_global_layout(self, h, layout):
|
||||
"""Sets global attantion layout used by the given head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
if (self.global_block_end_indices is None):
|
||||
for idx in self.global_block_indices:
|
||||
# if the global block idx is in the range of the sequence blocks
|
||||
if (idx < num_blocks):
|
||||
#global rows
|
||||
if (self.horizontal_global_attention):
|
||||
layout[h, idx, :] = 1
|
||||
|
||||
#global columns
|
||||
first_row = 0 if self.attention == 'bidirectional' else idx
|
||||
layout[h, first_row:, idx] = 1
|
||||
else:
|
||||
for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)):
|
||||
# if the global block idx is in the range of the sequence blocks
|
||||
if (start_idx < num_blocks):
|
||||
end_idx = min(end_idx, num_blocks)
|
||||
#global rows
|
||||
if (self.horizontal_global_attention):
|
||||
layout[h, start_idx:end_idx, :] = 1
|
||||
|
||||
#global columns
|
||||
first_row = 0 if self.attention == 'bidirectional' else start_idx
|
||||
layout[h, first_row:, start_idx:end_idx] = 1
|
||||
return layout
|
||||
|
||||
def make_layout(self, seq_len):
|
||||
"""Generates `Variable` sparsity layout used by each head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
seq_len: required: an integer determining the sequence length; must be divisible by the block size.
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable` sparsity layout of all head
|
||||
"""
|
||||
|
||||
layout = self.setup_layout(seq_len)
|
||||
for h in range(0, self.num_layout_heads):
|
||||
layout = self.set_random_layout(h, layout)
|
||||
layout = self.set_local_layout(h, layout)
|
||||
layout = self.set_global_layout(h, layout)
|
||||
|
||||
layout = self.check_and_propagate_first_head_layout(layout)
|
||||
return layout
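A hedged sketch of a `Variable` layout with mixed local windows, one random block per row, and a global block (values are illustrative assumptions):
# Hedged sketch; VariableSparsityConfig with growing local windows.
from deepspeed.ops.sparse_attention.sparsity_config import VariableSparsityConfig
config = VariableSparsityConfig(num_heads=8,
                                block=16,
                                num_random_blocks=1,
                                local_window_blocks=[2, 4, 8],   # first, second, then remaining windows
                                global_block_indices=[0],        # block 0 attends / is attended globally
                                attention='bidirectional')
layout = config.make_layout(seq_len=1024)    # tensor of shape (8, 64, 64)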
|
||||
|
||||
|
||||
class BigBirdSparsityConfig(SparsityConfig):
|
||||
"""Configuration class to store `BigBird` sparsity configuration.
|
||||
For more details about this sparsity config, please see `Big Bird: Transformers for Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
|
||||
This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_heads,
|
||||
block=16,
|
||||
different_layout_per_head=False,
|
||||
num_random_blocks=1,
|
||||
num_sliding_window_blocks=3,
|
||||
num_global_blocks=1):
|
||||
"""Initialize the BigBird Sparsity Pattern Config.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
|
||||
|
||||
Arguments:
|
||||
num_heads: required: an integer determining number of attention heads of the layer.
|
||||
block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.
|
||||
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
|
||||
num_random_blocks: optional: an integer determining the number of random blocks in each block row.
|
||||
num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding local attention window.
|
||||
num_global_blocks: optional: an integer determining how many consecutive blocks, starting from index 0, are considered as global attention. Global block tokens will be attended by all other block tokens and will attend to all other block tokens as well.
|
||||
"""
|
||||
|
||||
super().__init__(num_heads, block, different_layout_per_head)
|
||||
|
||||
self.num_random_blocks = num_random_blocks
|
||||
self.num_sliding_window_blocks = num_sliding_window_blocks
|
||||
self.num_global_blocks = num_global_blocks
|
||||
|
||||
def set_random_layout(self, h, layout):
|
||||
"""Sets random attantion layout used by the given head in the sparse attention.
|
||||
Note) By default, it assumes there will be a unique random block layout for all heads; unless the `different_layout_per_head` parameter is set, in which case each head can have a different random layout.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
if (num_blocks < self.num_random_blocks):
|
||||
raise ValueError(
|
||||
f'Number of random blocks, {self.num_random_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
|
||||
)
|
||||
|
||||
for row in range(0, num_blocks):
|
||||
rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
|
||||
layout[h, row, rnd_cols] = 1
|
||||
return layout
|
||||
|
||||
def set_sliding_window_layout(self, h, layout):
|
||||
"""Sets sliding local attantion layout used by the given head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
if (num_blocks < self.num_sliding_window_blocks):
|
||||
raise ValueError(
|
||||
f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
|
||||
)
|
||||
|
||||
w = self.num_sliding_window_blocks // 2
|
||||
for row in range(0, num_blocks):
|
||||
start = max(0, row - w)
|
||||
end = min(row + w + 1, num_blocks)
|
||||
layout[h, row, start:end] = 1
|
||||
return layout
|
||||
|
||||
def set_global_layout_itc(self, h, layout):
|
||||
"""Sets global attantion layout used by the given head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
if (num_blocks < self.num_global_blocks):
|
||||
raise ValueError(
|
||||
f'Number of global blocks, {self.num_global_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
|
||||
)
|
||||
|
||||
#global rows
|
||||
layout[h, 0:self.num_global_blocks, :] = 1
|
||||
|
||||
#global columns
|
||||
layout[h, :, 0:self.num_global_blocks] = 1
|
||||
|
||||
return layout
|
||||
|
||||
def make_layout(self, seq_len):
|
||||
"""Generates `BigBird` sparsity layout used by each head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
seq_len: required: an integer determining the sequence length; must be divisible by the block size.
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird` sparsity layout of all head
|
||||
"""
|
||||
|
||||
layout = self.setup_layout(seq_len)
|
||||
for h in range(0, self.num_layout_heads):
|
||||
layout = self.set_random_layout(h, layout)
|
||||
layout = self.set_sliding_window_layout(h, layout)
|
||||
layout = self.set_global_layout_itc(h, layout)
|
||||
|
||||
layout = self.check_and_propagate_first_head_layout(layout)
|
||||
return layout
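A hedged sketch of a `BigBird`-style layout combining random, sliding-window, and global blocks (values are illustrative assumptions):
# Hedged sketch; BigBirdSparsityConfig layout.
from deepspeed.ops.sparse_attention.sparsity_config import BigBirdSparsityConfig
config = BigBirdSparsityConfig(num_heads=8,
                               block=16,
                               num_random_blocks=2,
                               num_sliding_window_blocks=3,
                               num_global_blocks=1)
layout = config.make_layout(seq_len=1024)   # random + sliding-window + global blocks, per head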
|
||||
|
||||
|
||||
class BSLongformerSparsityConfig(SparsityConfig):
|
||||
"""Configuration class to store edited `Longformer` sparsity configuration.
|
||||
|
||||
Note) this is a block-sparse version of the Longformer, which is slightly different from the original Longformer; the original uses element-wise sparsity.
|
||||
|
||||
For more details about this sparsity config, please see `Longformer: The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
|
||||
This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_heads,
|
||||
block=16,
|
||||
different_layout_per_head=False,
|
||||
num_sliding_window_blocks=3,
|
||||
global_block_indices=[0],
|
||||
global_block_end_indices=None):
|
||||
"""Initialize the edited `Longformer` Sparsity Pattern Config.
|
||||
|
||||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
|
||||
|
||||
Arguments:
|
||||
num_heads: required: an integer determining number of attention heads of the layer.
|
||||
block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.
|
||||
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
|
||||
|
||||
num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding local attention window.
|
||||
global_block_indices: optional: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Default value is only index 0. Notice that if global_block_end_indices parameter is set, this parameter is used as starting index of each global window.
|
||||
global_block_end_indices: optional: a list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global_block_indices parameter, and combining this two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention.
|
||||
"""
|
||||
|
||||
super().__init__(num_heads, block, different_layout_per_head)
|
||||
|
||||
self.num_sliding_window_blocks = num_sliding_window_blocks
|
||||
self.global_block_indices = global_block_indices
|
||||
|
||||
if (global_block_end_indices is not None):
|
||||
if (len(global_block_indices) != len(global_block_end_indices)):
|
||||
raise ValueError(
|
||||
f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!'
|
||||
)
|
||||
for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)):
|
||||
if start_idx >= end_idx:
|
||||
raise ValueError(
|
||||
f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!'
|
||||
)
|
||||
self.global_block_end_indices = global_block_end_indices
|
||||
|
||||
def set_sliding_window_layout(self, h, layout):
|
||||
"""Sets sliding local attantion layout used by the given head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
if (num_blocks < self.num_sliding_window_blocks):
|
||||
raise ValueError(
|
||||
f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
|
||||
)
|
||||
|
||||
w = self.num_sliding_window_blocks // 2
|
||||
for row in range(0, num_blocks):
|
||||
start = max(0, row - w)
|
||||
end = min(row + w + 1, num_blocks)
|
||||
layout[h, row, start:end] = 1
|
||||
return layout
|
||||
|
||||
def set_global_layout(self, h, layout):
|
||||
"""Sets global attantion layout used by the given head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
h: required: an integer determining head index
|
||||
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
|
||||
"""
|
||||
|
||||
num_blocks = layout.shape[1]
|
||||
if (self.global_block_end_indices is None):
|
||||
for idx in self.global_block_indices:
|
||||
# if the global block idx is in the range of the sequence blocks
|
||||
if (idx < num_blocks):
|
||||
#global rows
|
||||
layout[h, idx, :] = 1
|
||||
|
||||
#global columns
|
||||
layout[h, :, idx] = 1
|
||||
else:
|
||||
for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)):
|
||||
# if the global block idx is in the range of the sequence blocks
|
||||
if (start_idx < num_blocks):
|
||||
end_idx = min(end_idx, num_blocks)
|
||||
#global rows
|
||||
layout[h, start_idx:end_idx, :] = 1
|
||||
|
||||
#global columns
|
||||
layout[h, :, start_idx:end_idx] = 1
|
||||
return layout
|
||||
|
||||
def make_layout(self, seq_len):
|
||||
"""Generates edited `Longformer` sparsity layout used by each head in the sparse attention.
|
||||
|
||||
Arguments:
|
||||
seq_len: required: an integer determining the sequence length; must be divisible by the block size.
|
||||
|
||||
Return:
|
||||
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer` sparsity layout of all head
|
||||
"""
|
||||
|
||||
layout = self.setup_layout(seq_len)
|
||||
for h in range(0, self.num_layout_heads):
|
||||
layout = self.set_sliding_window_layout(h, layout)
|
||||
layout = self.set_global_layout(h, layout)
|
||||
|
||||
layout = self.check_and_propagate_first_head_layout(layout)
|
||||
return layout
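A hedged sketch of the block-sparse Longformer layout (values are illustrative assumptions):
# Hedged sketch; BSLongformerSparsityConfig layout.
from deepspeed.ops.sparse_attention.sparsity_config import BSLongformerSparsityConfig
config = BSLongformerSparsityConfig(num_heads=8,
                                    block=16,
                                    num_sliding_window_blocks=3,
                                    global_block_indices=[0, 1])
layout = config.make_layout(seq_len=1024)   # sliding-window + block-level global attention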
|
32
deepspeed/ops/sparse_attention/trsrc/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def _build_file_index(directory, suffix='.tr'):
|
||||
"""Build an index of source files and their basenames in a given directory.
|
||||
|
||||
Args:
|
||||
directory (string): the directory to index
|
||||
suffix (string): index files with this suffix
|
||||
|
||||
Returns:
|
||||
list: A list of tuples of the form [(basename, absolute path), ...]
|
||||
"""
|
||||
|
||||
index = []
|
||||
|
||||
for fname in os.listdir(directory):
|
||||
if fname.endswith(suffix):
|
||||
basename = fname[:fname.rfind(suffix)] # strip the suffix
|
||||
path = os.path.join(directory, fname)
|
||||
index.append((basename, path))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
# Go over all local source files and parse them as strings
|
||||
_module = sys.modules[_build_file_index.__module__]
|
||||
_directory = os.path.dirname(os.path.realpath(__file__))
|
||||
for name, fname in _build_file_index(_directory):
|
||||
with open(fname, 'r') as fin:
|
||||
setattr(_module, name, fin.read())
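In effect, every `*.tr` file in this directory becomes a module-level string attribute holding its kernel source. A small hedged sketch (the `matmul` attribute corresponds to the matmul.tr file below; other attribute names are assumptions):
# Hedged sketch; the Triton-C source files are exposed as plain strings.
from deepspeed.ops.sparse_attention import trsrc
print(type(trsrc.matmul))                 # <class 'str'> -- raw source of matmul.tr
print(trsrc.matmul.splitlines()[0])       # first line of the kernel source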
|
201
deepspeed/ops/sparse_attention/trsrc/matmul.tr
Normal file
@ -0,0 +1,201 @@
|
||||
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
|
||||
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
|
||||
|
||||
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
|
||||
TYPE* B __readonly __noalias __aligned(16),
|
||||
TYPE* C __noalias __aligned(16),
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8),
|
||||
long stride_za __multipleof(8),
|
||||
long stride_zb __multipleof(8),
|
||||
long stride_zc __multipleof(8),
|
||||
long stride_ha __multipleof(8),
|
||||
long stride_hb __multipleof(8),
|
||||
long stride_hc __multipleof(8),
|
||||
int DS0, int DS1,
|
||||
int SDD_K __multipleof(16),
|
||||
int SDD_off_width,
|
||||
int* lut, int* locks, int nlocks) {
|
||||
/* ---------------- */
|
||||
/* Prologue */
|
||||
/* ---------------- */
|
||||
// program ids
|
||||
int pid0 = get_program_id(0);
|
||||
int pid1 = get_program_id(1);
|
||||
int pidz = get_program_id(2);
|
||||
#ifdef SDD
|
||||
// load LUT header
|
||||
pid1 = pid1 + SDD_off_width;
|
||||
int blockidm[TM] = (0 ... TM) / BLOCK;
|
||||
int blockidn[TN] = (0 ... TN) / BLOCK;
|
||||
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
|
||||
int offlutn[TN] = blockidn*4;
|
||||
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
|
||||
int z = *(header + 0);
|
||||
int i[TM] = *(header + 1 + offlutm);
|
||||
int j[TN] = *(header + 2 + offlutn);
|
||||
int AS1 = SDD_K / TZ;
|
||||
int lockid = select(TZ > 1, 1, 0);
|
||||
int offka = pid0 * AS1;
|
||||
int offkb = pid0 * AS1;
|
||||
int offmc = 0;
|
||||
int offnc = 0;
|
||||
int offpa = 0;
|
||||
int offpb = 0;
|
||||
int maxid = TZ;
|
||||
int offhc = 0;
|
||||
int offha = z;
|
||||
int offhb = z;
|
||||
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
|
||||
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
|
||||
#else
|
||||
// load LUT header
|
||||
int *header = lut + pid0 * 6;
|
||||
int offset = *(header + 0);
|
||||
int AS1 = *(header + 1);
|
||||
int column = *(header + 2);
|
||||
int depth = *(header + 3);
|
||||
int lockid = *(header + 4);
|
||||
int maxid = *(header + 5);
|
||||
int *pinc = lut + offset;
|
||||
int offhc = depth;
|
||||
#ifdef DSD
|
||||
// output offset
|
||||
int offnc = pid1 * TN;
|
||||
int offmc = column * TM;
|
||||
int offpc = 0;
|
||||
// dense input offset
|
||||
int offnb = pid1 * TN;
|
||||
int offkb __multipleof(8) = *pinc;
|
||||
int offpb = 0;
|
||||
// sparse input offset
|
||||
int offma = 0;
|
||||
int offka = 0;
|
||||
long offpa __multipleof(8) = *(pinc + 1);
|
||||
offpa = offpa * BLOCK * BLOCK;
|
||||
int offha = 0;
|
||||
int offhb = depth;
|
||||
#endif
|
||||
#ifdef DDS
|
||||
// output offset
|
||||
int offmc = pid1 * TM;
|
||||
int offnc = column * TN;
|
||||
int offpc = 0;
|
||||
// dense input offset
|
||||
int offma = pid1 * TM;
|
||||
int offka __multipleof(8) = *pinc;
|
||||
int offpa = 0;
|
||||
// sparse input offset
|
||||
int offnb = 0;
|
||||
int offkb = 0;
|
||||
long offpb __multipleof(8) = *(pinc + 1);
|
||||
offpb = offpb * BLOCK * BLOCK;
|
||||
int offha = depth;
|
||||
int offhb = 0;
|
||||
#endif
|
||||
int ram[TM] = offma + 0 ... TM;
|
||||
int rbn[TN] = offnb + 0 ... TN;
|
||||
#endif
|
||||
// initialize a, b pointers
|
||||
int rka[TK] = offka + 0 ... TK;
|
||||
int rkb[TK] = offkb + 0 ... TK;
|
||||
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
|
||||
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
|
||||
// pre-fetch
|
||||
#ifdef DDS
|
||||
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
|
||||
#else
|
||||
bool checkam[TM, TK] = AS1 > 0;
|
||||
#endif
|
||||
#ifdef DSD
|
||||
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
|
||||
#else
|
||||
bool checkbn[TK, TN] = AS1 > 0;
|
||||
#endif
|
||||
TYPE a[TM, TK] = checkam ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkbn ? *pb : 0;
|
||||
|
||||
/* ---------------- */
|
||||
/* Inner Loop */
|
||||
/* ---------------- */
|
||||
// create result tile
|
||||
float acc[TM, TN] = 0;
|
||||
int step = TK;
|
||||
for(int k = AS1; k > 0; k -= step) {
|
||||
acc += a @ b;
|
||||
// update pointers
|
||||
#ifdef SDD
|
||||
int inc_a = TK * STRIDE_AK;
|
||||
int inc_b = TK * STRIDE_BK;
|
||||
#else
|
||||
pinc += 2;
|
||||
#ifdef DSD
|
||||
int inc_b __multipleof(8) = *pinc;
|
||||
int inc_a __multipleof(8) = *(pinc + 1);
|
||||
inc_b = inc_b * STRIDE_BK;
|
||||
#endif
|
||||
#ifdef DDS
|
||||
int inc_a __multipleof(8) = *pinc;
|
||||
int inc_b __multipleof(8) = *(pinc + 1);
|
||||
inc_a = inc_a * STRIDE_AK;
|
||||
#endif
|
||||
#endif
|
||||
pa += inc_a;
|
||||
pb += inc_b;
|
||||
// pre-fetch
|
||||
bool checkak[TM, TK] = k > TK;
|
||||
bool checkbk[TK, TN] = k > TK;
|
||||
bool checka[TM, TK] = checkam && checkak;
|
||||
bool checkb[TK, TN] = checkbk && checkbn;
|
||||
a = *?(checka)pa;
|
||||
b = *?(checkb)pb;
|
||||
}
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
/* ---------------- */
|
||||
/* Epilogue */
|
||||
/* ---------------- */
|
||||
// initialize c pointers
|
||||
#ifdef SDD
|
||||
bool checkc[TM, TN] = 1;
|
||||
// rematerialize
|
||||
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
|
||||
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
|
||||
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
|
||||
int rr_offlutn[TN] = rr_blockidn*4;
|
||||
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
|
||||
int bkid[TM, TN] = *(header + off_bkid);
|
||||
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
|
||||
// range within blocks
|
||||
int rcm[TM] = (0 ... TM) % BLOCK;
|
||||
int rcn[TN] = (0 ... TN) % BLOCK;
|
||||
#else
|
||||
int rcm[TM] = offmc + 0 ... TM;
|
||||
int rcn[TN] = offnc + 0 ... TN;
|
||||
#ifdef DSD
|
||||
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
|
||||
#endif
|
||||
#ifdef DDS
|
||||
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
|
||||
#endif
|
||||
#endif
|
||||
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
|
||||
// write-back directly
|
||||
if(lockid == 0) {
|
||||
*?(checkc) pc = c;
|
||||
}
|
||||
// accumulate partial result using spin-locks
|
||||
else {
|
||||
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
|
||||
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
*?(checkc) pc = c;
|
||||
else
|
||||
*?(checkc) pc = c + *?(checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % maxid);
|
||||
atomic_xchg(plock, 0);
|
||||
}
|
||||
}
|
54
deepspeed/ops/sparse_attention/trsrc/softmax_bwd.tr
Normal file
@ -0,0 +1,54 @@
|
||||
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
|
||||
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
|
||||
|
||||
__global__ void softmax_bwd(TYPE * X __readonly __noalias __aligned(16),
|
||||
float scale,
|
||||
TYPE* DX __readonly __noalias __aligned(16),
|
||||
int* LUT,
|
||||
int sizemax,
|
||||
long stride_zx __multipleof(BLOCK),
|
||||
long stride_zdx __multipleof(BLOCK)) {
|
||||
int pidhm = get_program_id(0);
|
||||
int pidz = get_program_id(1);
|
||||
|
||||
// create index ranges
|
||||
int rxm = pidhm % BLOCK;
|
||||
int rbm = pidhm / BLOCK;
|
||||
int rxn[TN] = (0 ... TN) % BLOCK;
|
||||
int rbn[TN] = (0 ... TN) / BLOCK;
|
||||
|
||||
// extract information from look-up table
|
||||
int* header = LUT + rbm * 2;
|
||||
int size = *(header + 0);
|
||||
int offset = *(header + 1);
|
||||
|
||||
// bounds checking on lut
|
||||
bool check[TN] = rbn < size;
|
||||
int rbmn[TN] = check ? rbn : size - 1;
|
||||
|
||||
// initialize pointers to block-sparse input
|
||||
long blockid[TN] = *(LUT + offset + rbmn*4);
|
||||
|
||||
TYPE* px[TN] = X + pidz * stride_zx
|
||||
+ blockid * BLOCK * BLOCK
|
||||
+ rxm * BLOCK
|
||||
+ rxn;
|
||||
|
||||
TYPE* pdx[TN] = DX + pidz * stride_zdx
|
||||
+ blockid * BLOCK * BLOCK
|
||||
+ rxm * BLOCK
|
||||
+ rxn;
|
||||
|
||||
// compute fused softmax backward
|
||||
TYPE x[TN] = check ? *px : 0;
|
||||
TYPE dx[TN] = check ? *pdx : 0;
|
||||
float Fdx[TN] = dx;
|
||||
float Fx[TN] = x;
|
||||
float Fxdx[TN] = Fdx*Fx;
|
||||
float Fxdxsum = Fxdx[+];
|
||||
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
|
||||
TYPE y[TN] = Fy;
|
||||
|
||||
// write-back
|
||||
*? (check)pdx = y;
|
||||
}
|
136
deepspeed/ops/sparse_attention/trsrc/softmax_fwd.tr
Normal file
@ -0,0 +1,136 @@
|
||||
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
|
||||
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
|
||||
|
||||
__global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
|
||||
float scale,
|
||||
int *LUT __readonly __noalias __aligned(16),
|
||||
TYPE *RPE __readonly __noalias __aligned(16),
|
||||
TYPE *KP_M __readonly __noalias __aligned(16),
|
||||
TYPE *ATTN_M __readonly __noalias __aligned(16),
|
||||
int num_blocks,
|
||||
int sizemax,
|
||||
long stride_zx __multipleof(BLOCK),
|
||||
long stride_zrpe __multipleof(BLOCK),
|
||||
int stride_hrpe __multipleof(BLOCK),
|
||||
int stride_srpe __multipleof(BLOCK),
|
||||
int stride_zkpm __multipleof(BLOCK),
|
||||
int stride_zattnm __multipleof(BLOCK)){
|
||||
int pidhm = get_program_id(0);
|
||||
int pidz = get_program_id(1);
|
||||
|
||||
// create index ranges
|
||||
int rxm = pidhm % BLOCK;
|
||||
int rbm = pidhm / BLOCK;
|
||||
int rxn[TN] = (0 ... TN) % BLOCK;
|
||||
int rbn[TN] = (0 ... TN) / BLOCK;
|
||||
|
||||
// extract information from look-up table
|
||||
int* header = LUT + rbm * 2;
|
||||
int size = *(header + 0);
|
||||
int offset = *(header + 1);
|
||||
|
||||
bool check[TN] = rbn < size;
|
||||
int rbmn[TN] = check ? rbn : size - 1;
|
||||
|
||||
// block id and column id
|
||||
long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
|
||||
long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
|
||||
long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
|
||||
long headid [TN] = *(LUT + offset + rbmn*4 + 3);
|
||||
|
||||
// pointers to X
|
||||
TYPE* px[TN] = X + pidz * stride_zx
|
||||
+ blockid * BLOCK * BLOCK
|
||||
+ rxm * BLOCK
|
||||
+ rxn;
|
||||
#ifdef APPLY_RPE
|
||||
// pointers to relative position embedding
|
||||
TYPE* prpe[TN] = RPE + pidz * stride_zrpe
|
||||
+ headid * stride_hrpe
|
||||
+ columnid * BLOCK
|
||||
+ rowid * BLOCK * stride_srpe
|
||||
+ rxm * stride_srpe
|
||||
+ rxn;
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_KP_MASK
|
||||
// pointers to key padding mask
|
||||
TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
|
||||
+ columnid * BLOCK
|
||||
+ rxn;
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
// pointers to attention mask
|
||||
TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
|
||||
+ rowid * BLOCK * stride_zattnm
|
||||
+ rxm * stride_zattnm
|
||||
+ rxn;
|
||||
#endif
|
||||
|
||||
// load input
|
||||
TYPE x[TN] = check ? *px : -INFINITY;
|
||||
|
||||
#ifdef APPLY_RPE
|
||||
// load relative position embedding
|
||||
TYPE rpe[TN] = check ? *prpe : 0;
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_KP_MASK
|
||||
// load key-padding mask
|
||||
TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
// load attention mask
|
||||
TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
|
||||
#endif
|
||||
|
||||
// compute softmax in float
|
||||
#ifdef APPLY_RPE
|
||||
float Frpe[TN] = rpe;
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_KP_MASK
|
||||
float Fkp_m[TN] = kp_m;
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
float Fattn_m[TN] = attn_m;
|
||||
#endif
|
||||
|
||||
#ifdef KP_MASK_MUL
|
||||
Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
|
||||
#endif
|
||||
|
||||
#ifdef ATTN_MASK_MUL
|
||||
Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
|
||||
#endif
|
||||
|
||||
float Fx[TN] = x;
|
||||
|
||||
#ifdef APPLY_SCALE
|
||||
Fx = Fx * scale; // apply scale
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_RPE
|
||||
Fx = Fx + Frpe; // apply relative position embedding
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_KP_MASK
|
||||
Fx = Fx + Fkp_m; // apply key padding mask
|
||||
#endif
|
||||
|
||||
#ifdef APPLY_ATTN_MASK
|
||||
Fx = Fx + Fattn_m; // apply attention mask
|
||||
#endif
|
||||
|
||||
float Fxmax = Fx[max];
|
||||
float Fy[TN] = exp(Fx - Fxmax);
|
||||
float Fysum = (check ? Fy : 0)[+];
|
||||
|
||||
// write-back in half/float
|
||||
TYPE y[TN] = Fy;
|
||||
TYPE ysum = Fysum;
|
||||
*?(check)px = y / ysum;
|
||||
}
|
1
deepspeed/ops/transformer/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from deepspeed.ops.transformer.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
|
36
deepspeed/pt/deepspeed_cuda.py → deepspeed/ops/transformer/transformer.py
Executable file → Normal file
@ -1,10 +1,16 @@
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
'''
|
||||
Copyright 2020 The Microsoft DeepSpeed Team
|
||||
'''
|
||||
import json
|
||||
import math
|
||||
import deepspeed_transformer_cuda as ds_transformer_cuda
|
||||
import deepspeed_stochastic_transformer_cuda as ds_stochastic_transformer_cuda
|
||||
import importlib
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
# Cuda modules will be imported if needed
|
||||
transformer_cuda_module = None
|
||||
stochastic_transformer_cuda_module = None
|
||||
|
||||
|
||||
class TransformerConfig():
|
||||
@ -159,7 +165,7 @@ class DeepSpeedTransformerFunction(Function):
|
||||
if bsz > config.batch_size:
|
||||
raise ValueError('Input batch size exceeds the limit.')
|
||||
|
||||
cuda_module = ds_stochastic_transformer_cuda if config.stochastic_mode else ds_transformer_cuda
|
||||
cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
|
||||
forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32
|
||||
|
||||
(output,
|
||||
@ -321,7 +327,7 @@ class DeepSpeedTransformerFunction(Function):
|
||||
norm_w,
|
||||
norm_b) = ctx.saved_tensors
|
||||
|
||||
cuda_module = ds_stochastic_transformer_cuda if ctx.config.stochastic_mode else ds_transformer_cuda
|
||||
cuda_module = stochastic_transformer_cuda_module if ctx.config.stochastic_mode else transformer_cuda_module
|
||||
backward_func = cuda_module.backward_fp16 if ctx.config.fp16 else cuda_module.backward_fp32
|
||||
|
||||
(grad_input,
|
||||
@ -457,8 +463,22 @@ class DeepSpeedTransformerLayer(nn.Module):
|
||||
self.norm_w = initial_weights[7]
|
||||
self.norm_b = initial_biases[7]
|
||||
|
||||
# Import cuda modules if needed
|
||||
global transformer_cuda_module, stochastic_transformer_cuda_module
|
||||
if transformer_cuda_module is None or stochastic_transformer_cuda_module is None:
|
||||
try:
|
||||
transformer_cuda_module = importlib.import_module(
|
||||
"deepspeed.ops.transformer.transformer_cuda")
|
||||
stochastic_transformer_cuda_module = importlib.import_module(
|
||||
"deepspeed.ops.transformer.stochastic_transformer_cuda")
|
||||
except ImportError as err:
|
||||
print(
|
||||
"Unable to import transformer cuda extension, please build DeepSpeed with cuda/cpp extensions."
|
||||
)
|
||||
raise err
|
||||
|
||||
# create the layer in cuda kernels.
|
||||
cuda_module = ds_stochastic_transformer_cuda if self.config.stochastic_mode else ds_transformer_cuda
|
||||
cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module
|
||||
create_layer_func = cuda_module.create_transformer_layer_fp16 if self.config.fp16 else cuda_module.create_transformer_layer_fp32
|
||||
|
||||
create_layer_func(self.config.layer_id,
|
0
deepspeed/runtime/__init__.py
Normal file
@ -13,16 +13,17 @@ b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
import contextlib
|
||||
import copy
|
||||
import torch.distributed as dist
|
||||
import torch
|
||||
import contextlib
|
||||
import torch.distributed as dist
|
||||
|
||||
from torch import _C
|
||||
from torch.cuda import _lazy_call, device as device_ctx_manager
|
||||
from deepspeed.pt.deepspeed_timer import SynchronizedWallClockTimer as Timers
|
||||
import torch.distributed as dist
|
||||
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
|
||||
from deepspeed.pt.log_utils import logger
|
||||
|
||||
from deepspeed.runtime.config import DeepSpeedConfig
|
||||
from deepspeed.utils import logger
|
||||
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers
|
||||
|
||||
#DeepSpeed Checkpointing Enabled or Disabled
|
||||
deepspeed_checkpointing_enabled = False
|
@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT license.
|
||||
"""
|
||||
|
||||
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
|
||||
from deepspeed.runtime.config_utils import get_scalar_param
|
||||
|
||||
#########################################
|
||||
# DeepSpeed Activation Checkpointing
|
@ -6,12 +6,12 @@ Licensed under the MIT license.
|
||||
import torch
|
||||
import json
|
||||
import copy
|
||||
from deepspeed.pt.deepspeed_constants import *
|
||||
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
|
||||
from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
|
||||
from deepspeed.pt.deepspeed_zero_config import DeepSpeedZeroConfig
|
||||
from deepspeed.pt.deepspeed_checkpointing_config import DeepSpeedActivationCheckpointingConfig
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.runtime.constants import *
|
||||
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
|
||||
from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
|
||||
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
|
||||
from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
|
||||
from deepspeed.utils import logger
|
||||
|
||||
TENSOR_CORE_ALIGN_SIZE = 8
|
||||
ADAM_OPTIMIZER = 'adam'
|
||||
@ -158,6 +158,177 @@ def get_gradient_clipping(param_dict):
|
||||
return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)
|
||||
|
||||
|
||||
def get_sparse_attention(param_dict):
|
||||
if SPARSE_ATTENTION in param_dict.keys():
|
||||
sparsity = param_dict[SPARSE_ATTENTION]
|
||||
mode = get_sparse_attention_mode(sparsity)
|
||||
|
||||
if (mode == SPARSE_DENSE_MODE):
|
||||
return get_sparse_dense_config(sparsity)
|
||||
elif (mode == SPARSE_FIXED_MODE):
|
||||
return get_sparse_fixed_config(sparsity)
|
||||
elif (mode == SPARSE_VARIABLE_MODE):
|
||||
return get_sparse_variable_config(sparsity)
|
||||
elif (mode == SPARSE_BIGBIRD_MODE):
|
||||
return get_sparse_bigbird_config(sparsity)
|
||||
elif (mode == SPARSE_BSLONGFORMER_MODE):
|
||||
return get_sparse_bslongformer_config(sparsity)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Given sparsity mode, {mode}, has not been implemented yet!')
|
||||
|
||||
else:
|
||||
return None
|
||||
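A minimal sketch of how this dispatch behaves when called directly on a user config dict (import path per the new `deepspeed.runtime.config` module introduced in this PR; the defaults follow the constants defined later in `constants.py`):

```python
from deepspeed.runtime.config import get_sparse_attention

# "mode" defaults to "fixed" when omitted; unknown modes raise NotImplementedError.
cfg = get_sparse_attention({"sparse_attention": {"block": 32, "num_local_blocks": 8}})
print(cfg["mode"], cfg["block"], cfg["num_local_blocks"])   # fixed 32 8

# No sparse_attention section configured -> None
print(get_sparse_attention({}))                             # None
```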
|
||||
|
||||
def get_sparse_dense_config(sparsity):
|
||||
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
|
||||
return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block}
|
||||
|
||||
|
||||
def get_sparse_fixed_config(sparsity):
|
||||
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
|
||||
different_layout_per_head = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
|
||||
num_local_blocks = get_scalar_param(sparsity,
|
||||
SPARSE_NUM_LOCAL_BLOCKS,
|
||||
SPARSE_NUM_LOCAL_BLOCKS_DEFAULT)
|
||||
num_global_blocks = get_scalar_param(sparsity,
|
||||
SPARSE_NUM_GLOBAL_BLOCKS,
|
||||
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
|
||||
attention = get_scalar_param(sparsity,
|
||||
SPARSE_ATTENTION_TYPE,
|
||||
SPARSE_ATTENTION_TYPE_DEFAULT)
|
||||
horizontal_global_attention = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT)
|
||||
num_differnt_global_patterns = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS,
|
||||
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT)
|
||||
|
||||
return {
|
||||
SPARSE_MODE: SPARSE_FIXED_MODE,
|
||||
SPARSE_BLOCK: block,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
|
||||
SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks,
|
||||
SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
|
||||
SPARSE_ATTENTION_TYPE: attention,
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
|
||||
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_differnt_global_patterns
|
||||
}
|
||||
|
||||
|
||||
def get_sparse_variable_config(sparsity):
|
||||
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
|
||||
different_layout_per_head = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
|
||||
num_random_blocks = get_scalar_param(sparsity,
|
||||
SPARSE_NUM_RANDOM_BLOCKS,
|
||||
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
|
||||
local_window_blocks = get_scalar_param(sparsity,
|
||||
SPARSE_LOCAL_WINDOW_BLOCKS,
|
||||
SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT)
|
||||
global_block_indices = get_scalar_param(sparsity,
|
||||
SPARSE_GLOBAL_BLOCK_INDICES,
|
||||
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
|
||||
global_block_end_indices = get_scalar_param(sparsity,
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES,
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT)
|
||||
attention = get_scalar_param(sparsity,
|
||||
SPARSE_ATTENTION_TYPE,
|
||||
SPARSE_ATTENTION_TYPE_DEFAULT)
|
||||
horizontal_global_attention = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT)
|
||||
|
||||
return {
|
||||
SPARSE_MODE: SPARSE_VARIABLE_MODE,
|
||||
SPARSE_BLOCK: block,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
|
||||
SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
|
||||
SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks,
|
||||
SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
|
||||
SPARSE_ATTENTION_TYPE: attention,
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention
|
||||
}
|
||||
|
||||
|
||||
def get_sparse_bigbird_config(sparsity):
|
||||
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
|
||||
different_layout_per_head = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
|
||||
num_random_blocks = get_scalar_param(sparsity,
|
||||
SPARSE_NUM_RANDOM_BLOCKS,
|
||||
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
|
||||
num_sliding_window_blocks = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT)
|
||||
num_global_blocks = get_scalar_param(sparsity,
|
||||
SPARSE_NUM_GLOBAL_BLOCKS,
|
||||
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
|
||||
|
||||
return {
|
||||
SPARSE_MODE: SPARSE_BIGBIRD_MODE,
|
||||
SPARSE_BLOCK: block,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
|
||||
SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
|
||||
SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks
|
||||
}
|
||||
|
||||
|
||||
def get_sparse_bslongformer_config(sparsity):
|
||||
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
|
||||
different_layout_per_head = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
|
||||
num_sliding_window_blocks = get_scalar_param(
|
||||
sparsity,
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT)
|
||||
global_block_indices = get_scalar_param(sparsity,
|
||||
SPARSE_GLOBAL_BLOCK_INDICES,
|
||||
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
|
||||
global_block_end_indices = get_scalar_param(sparsity,
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES,
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT)
|
||||
|
||||
return {
|
||||
SPARSE_MODE: SPARSE_BSLONGFORMER_MODE,
|
||||
SPARSE_BLOCK: block,
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
|
||||
SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices
|
||||
}
|
||||
|
||||
|
||||
def get_sparse_attention_mode(param_dict):
|
||||
if SPARSE_MODE in param_dict.keys():
|
||||
return param_dict[SPARSE_MODE]
|
||||
else:
|
||||
return SPARSE_MODE_DEFAULT
|
||||
|
||||
|
||||
def get_sparse_attention_type(param_dict):
|
||||
if SPARSE_ATTENTION_TYPE in param_dict.keys():
|
||||
return param_dict[SPARSE_ATTENTION_TYPE]
|
||||
else:
|
||||
return SPARSE_ATTENTION_TYPE_DEFAULT
|
||||
|
||||
|
||||
def get_optimizer_name(param_dict):
|
||||
if OPTIMIZER in param_dict.keys() and \
|
||||
TYPE in param_dict[OPTIMIZER].keys():
|
||||
@ -358,6 +529,8 @@ class DeepSpeedConfig(object):
|
||||
self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
|
||||
self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
|
||||
|
||||
self.sparse_attention = get_sparse_attention(param_dict)
|
||||
|
||||
def _batch_assertion(self):
|
||||
|
||||
train_batch = self.train_batch_size
|
@ -17,6 +17,42 @@ ROUTE_ENCODE = "encode"
|
||||
TRAIN_BATCH_SIZE = "train_batch_size"
|
||||
TRAIN_BATCH_SIZE_DEFAULT = None
|
||||
|
||||
#############################################
|
||||
# Sparse attention
|
||||
#############################################
|
||||
SPARSE_ATTENTION = "sparse_attention"
|
||||
SPARSE_DENSE_MODE = "dense"
|
||||
SPARSE_FIXED_MODE = "fixed"
|
||||
SPARSE_VARIABLE_MODE = "variable"
|
||||
SPARSE_BIGBIRD_MODE = "bigbird"
|
||||
SPARSE_BSLONGFORMER_MODE = "bslongformer"
|
||||
SPARSE_MODE = "mode"
|
||||
SPARSE_MODE_DEFAULT = SPARSE_FIXED_MODE
|
||||
SPARSE_BLOCK = "block"
|
||||
SPARSE_BLOCK_DEFAULT = 16
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD = "different_layout_per_head"
|
||||
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT = False
|
||||
SPARSE_NUM_LOCAL_BLOCKS = "num_local_blocks"
|
||||
SPARSE_NUM_LOCAL_BLOCKS_DEFAULT = 4
|
||||
SPARSE_NUM_GLOBAL_BLOCKS = "num_global_blocks"
|
||||
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT = 1
|
||||
SPARSE_ATTENTION_TYPE = "attention"
|
||||
SPARSE_ATTENTION_TYPE_DEFAULT = "bidirectional"
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION = "horizontal_global_attention"
|
||||
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT = False
|
||||
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS = "num_differnt_global_patterns"
|
||||
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT = 1
|
||||
SPARSE_NUM_RANDOM_BLOCKS = "num_random_blocks"
|
||||
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT = 0
|
||||
SPARSE_LOCAL_WINDOW_BLOCKS = "local_window_blocks"
|
||||
SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT = [4]
|
||||
SPARSE_GLOBAL_BLOCK_INDICES = "global_block_indices"
|
||||
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT = [0]
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES = "global_block_end_indices"
|
||||
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT = None
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS = "num_sliding_window_blocks"
|
||||
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT = 3
|
||||
|
||||
#############################################
|
||||
# Optimizer and lr scheduler
|
||||
#############################################
|
@ -2,36 +2,35 @@
|
||||
Copyright 2019 The Microsoft DeepSpeed Team
|
||||
'''
|
||||
|
||||
import torch
|
||||
import os
|
||||
import torch
|
||||
import warnings
|
||||
import torch.distributed as dist
|
||||
|
||||
from apex import amp
|
||||
from torch.nn.modules import Module
|
||||
from torch.distributed.distributed_c10d import _get_global_rank
|
||||
from apex import amp
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
|
||||
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
|
||||
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
|
||||
from deepspeed.pt.log_utils import logger
|
||||
import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing
|
||||
|
||||
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
|
||||
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
|
||||
from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
|
||||
from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
|
||||
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
|
||||
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
|
||||
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
|
||||
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
|
||||
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
|
||||
from deepspeed.runtime.config import DeepSpeedConfig, \
|
||||
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
|
||||
|
||||
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
|
||||
from deepspeed.pt.deepspeed_constants import \
|
||||
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
|
||||
from deepspeed.runtime.constants import \
|
||||
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
|
||||
TORCH_DISTRIBUTED_DEFAULT_PORT, \
|
||||
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
|
||||
from deepspeed.runtime.csr_tensor import CSRTensor
|
||||
import deepspeed.runtime.lr_schedules as lr_schedules
|
||||
|
||||
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
|
||||
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
|
||||
from deepspeed.ops.lamb import FusedLamb
|
||||
|
||||
from deepspeed.utils import logger
|
||||
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
|
||||
|
||||
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
|
||||
SUMMARY_WRITER_DIR_NAME = "JobId"
|
||||
@ -92,7 +91,7 @@ def print_configuration(args, name):
|
||||
logger.info(' {} {} {}'.format(arg, dots, getattr(args, arg)))
|
||||
|
||||
|
||||
class DeepSpeedLight(Module):
|
||||
class DeepSpeedEngine(Module):
|
||||
r"""DeepSpeed engine for training.
|
||||
"""
|
||||
def __init__(self,
|
||||
@ -106,7 +105,7 @@ class DeepSpeedLight(Module):
|
||||
dist_init_required=None,
|
||||
collate_fn=None,
|
||||
config_params=None):
|
||||
super(DeepSpeedLight, self).__init__()
|
||||
super(DeepSpeedEngine, self).__init__()
|
||||
|
||||
self.client_optimizer = optimizer
|
||||
self.client_model_parameters = model_parameters
|
0
deepspeed/runtime/fp16/__init__.py
Normal file
@ -9,9 +9,9 @@ import torch
|
||||
import math
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
|
||||
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
|
||||
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
|
||||
from deepspeed.utils import logger
|
||||
|
||||
|
||||
class FP16_Optimizer(object):
|
@ -9,9 +9,9 @@ import torch
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
import math
|
||||
|
||||
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
|
||||
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
|
||||
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
|
||||
from deepspeed.utils import logger
|
||||
|
||||
|
||||
class FP16_UnfusedOptimizer(object):
|
@ -12,8 +12,8 @@ import argparse
|
||||
from torch.optim import Optimizer
|
||||
from typing import Union, List
|
||||
import math
|
||||
from deepspeed.pt.deepspeed_constants import *
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.runtime.constants import *
|
||||
from deepspeed.utils import logger
|
||||
|
||||
LR_SCHEDULE = 'lr_schedule'
|
||||
LR_RANGE_TEST = 'LRRangeTest'
|
@ -9,7 +9,7 @@ Helper functions and classes from multiple sources.
|
||||
import torch
|
||||
from torch._six import inf
|
||||
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.utils import logger
|
||||
|
||||
|
||||
class CheckOverflow(object):
|
0
deepspeed/runtime/zero/__init__.py
Normal file
@ -3,9 +3,8 @@ Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT license.
|
||||
"""
|
||||
|
||||
#from deepspeed.pt.deepspeed_constants import *
|
||||
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.runtime.config_utils import get_scalar_param
|
||||
from deepspeed.utils import logger
|
||||
|
||||
#########################################
|
||||
# ZeRO optimization
|
@ -4,11 +4,11 @@ import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from collections import defaultdict
|
||||
|
||||
from deepspeed.pt.zero_utils import _initialize_parameter_parallel_groups
|
||||
from deepspeed.pt.log_utils import log_dist, logger
|
||||
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
|
||||
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
|
||||
from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
|
||||
from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
|
||||
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
|
||||
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow
|
||||
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
|
||||
from deepspeed.utils import logger, log_dist
|
||||
|
||||
|
||||
def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
|
@ -10,16 +10,15 @@ import math
|
||||
from torch._six import inf
|
||||
from torch.autograd import Variable
|
||||
|
||||
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
|
||||
from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter
|
||||
from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS
|
||||
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
|
||||
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
|
||||
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
|
||||
from deepspeed.utils import logger
|
||||
|
||||
#Toggle this to true to enable correctness test
|
||||
#with gradient partitioning and without
|
||||
pg_correctness_test = False
|
||||
|
||||
from deepspeed.pt.log_utils import logger
|
||||
|
||||
try:
|
||||
from apex_C import flatten
|
||||
from apex_C import unflatten
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.utils import logger
|
||||
|
||||
|
||||
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
|
1
deepspeed/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from deepspeed.utils.logging import logger, log_dist
|
@ -6,7 +6,7 @@ import time
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from deepspeed.pt.log_utils import logger
|
||||
from deepspeed.utils import logger
|
||||
|
||||
|
||||
def print_rank_0(message):
|
@ -335,3 +335,43 @@ Enabling and configure ZeRO memory optimizations
|
||||
| Description | Default |
|
||||
| ------------------------------------------------------------ | ------- |
|
||||
| Logs the forward and backward time for each checkpoint function | `false` |
|
||||
|
||||
### Sparse Attention
|
||||
|
||||
***sparse\_attention***: [dictionary]
|
||||
|
||||
| Fields | Value | Example |
|
||||
| ------ | ------------------------------------------------------------ | ------------------------------ |
|
||||
| mode | A string determining sparsity structure type. Deepspeed currently supports `"dense"`, `"fixed"`, `"bigbird"`, `"bslongformer"`, and `"variable"`. | `"fixed"` |
|
||||
| block | An integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such blocks, `Block X Block`. | 16 |
|
||||
| different\_layout\_per\_head | A boolean determining if each head should be assigned a different sparsity layout; this will be satisfied based on availability. | false |
|
||||
| num\_local\_blocks | An integer determining the number of blocks in a local attention window; only used in `"fixed"` mode. | 4 |
|
||||
| num\_global\_blocks | An integer determining how many consecutive blocks in a local window are used as the representative of the window for global attention; used in `"fixed"` and `"bigbird"` modes. | 1 |
|
||||
| attention | A string determining the attention type. Attention can be `"unidirectional"`, as in autoregressive models, in which tokens attend only to tokens that appear before them in the context; in that case, the upper triangular part of the attention matrix is empty. Or it can be `"bidirectional"`, as in BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix mirrors the lower triangular part; used in `"fixed"` and `"variable"` modes. | `"bidirectional"` |
|
||||
| horizontal\_global\_attention | A boolean determining whether blocks that are the global representatives of a local window also attend to all other blocks. This is valid only if the attention type is `"bidirectional"`. Looking at the attention matrix, it means global attention includes not only the vertical blocks but also the horizontal blocks; used in `"fixed"` and `"variable"` modes. | false |
|
||||
| num\_different\_global\_patterns | An integer determining the number of different global attention layouts. While global attention can be fixed by which block(s) are representative of any local window, since there are multiple heads, each head can use a different global representative; used only in `"fixed"` mode. | 4 |
|
||||
| num\_random\_blocks | An integer determining the number of random blocks in each block row; used in `"variable"` and `"bigbird"` modes. | 0 |
|
||||
| local\_window\_blocks | A list of integers determining the number of blocks in each local attention window: the first number determines the number of blocks in the first local window, the second number the second window, and so on, with the last number determining the number of blocks in all remaining local windows; only used in `"variable"` mode. | [4] |
|
||||
| global\_block\_indices | A list of integers determining which blocks are considered global attention: all other token blocks attend to these blocks, and these blocks attend to all other token blocks. Note that if the global\_block\_end\_indices parameter is set, this parameter is instead used as the starting index of each global window; used in `"variable"` and `"bslongformer"` modes. | [0] |
|
||||
| global\_block\_end\_indices | A list of integers determining the end indices of global window blocks. By default this is not used, but if it is set, it must have the same size as the global\_block\_indices parameter; combining the two parameters, for each index i, blocks from global\_block\_indices[i] to global\_block\_end\_indices[i], exclusive, are considered global attention; used in `"variable"` and `"bslongformer"` modes. | None |
|
||||
| num\_sliding\_window\_blocks | An integer determining the number of blocks in sliding local attention window; used in `"bigbird"` and `"bslongformer"` modes. | 3 |
|
||||
|
||||
Example of ***sparse\_attention***
|
||||
|
||||
```json
|
||||
"sparse_attention": {
|
||||
"mode": "fixed",
|
||||
"block": 16,
|
||||
"different_layout_per_head": true,
|
||||
"num_local_blocks": 4,
|
||||
"num_global_blocks": 1,
|
||||
"attention": "bidirectional",
|
||||
"horizontal_global_attention": false,
|
||||
"num_different_global_patterns": 4,
|
||||
"num_random_blocks": 0,
|
||||
"local_window_blocks": [4],
|
||||
"global_block_indices": [0],
|
||||
"global_block_end_indices": None,
|
||||
"num_sliding_window_blocks": 3
|
||||
}
|
||||
```
|
||||
|
83
docs/_posts/2020-09-09-sparse-attention.md
Normal file
@ -0,0 +1,83 @@
|
||||
---
|
||||
layout: single
|
||||
title: "DeepSpeed Sparse Attention"
|
||||
excerpt: ""
|
||||
categories: news
|
||||
new_post: true
|
||||
date: 2020-09-09 01:00:00
|
||||
---
|
||||
|
||||
Attention-based deep learning models such as Transformers are highly effective at capturing relationships between tokens in an input sequence, even across long distances. As a result, they are used with text, image, and sound-based inputs, where the sequence length can be in the thousands of tokens. However, despite the effectiveness of attention modules at capturing long-term dependencies, in practice their application to long-sequence input is limited by the compute and memory requirements of the attention computation, which grow quadratically, `O(n^2)`, with the sequence length `n`.
|
||||
|
||||
To address this limitation, DeepSpeed offers a suite of sparse attention kernels, an instrumental technology that can reduce the compute and memory requirements of attention computation by orders of magnitude via block-sparse computation. The suite not only alleviates the memory bottleneck of attention calculation, but also performs the sparse computation efficiently. Its APIs allow convenient integration with any transformer-based model. Along with providing a wide spectrum of sparsity structures, it has the flexibility of handling any user-defined block-sparse structure. More specifically, sparse attention (SA) can be designed to compute local attention between nearby tokens, or global attention via summary tokens computed with local attention. Moreover, SA can also allow random attention, or any combination of local, global, and random attention, as shown in the following figure with blue, orange, and green blocks, respectively. As a result, SA decreases the memory footprint to `O(wn)`, in which `1 < w < n` is a parameter whose value depends on the attention structure.
|
||||
|
||||
{: .align-center}
|
||||
|
||||
This library is PyTorch-based and develops the required kernels through the [Triton](https://github.com/ptillet/triton) platform; the kernels are not written in CUDA, which leaves the door open for CPU/OpenCL/Vulkan support in the future. The library is an extension to DeepSpeed and can be used through DeepSpeed as well as standalone.
|
||||
The block-sparse computations handled by DeepSpeed Sparse Attention kernels are illustrated in the following figures for the forward and backward passes, respectively. In the figures, `S` stands for a `block-sparse matrix` and `D` for a `dense matrix`.
|
||||
|
||||
{: .align-center}
|
||||
|
||||
{: .align-center}
|
||||
|
||||
To learn more about sparsity configs and how to use this library, please check our [tutorial](https://github.com/microsoft/DeepSpeed-internal/tree/master/docs/_tutorials/sparse_attention.md), which provides detailed information about it.
|
||||
|
||||
## Performance Results
|
||||
|
||||
* **Power over 10x longer sequences**
|
||||
In a pre-training experiment, we ran the BERT model under three settings: dense, dense with activation checkpointing, and sparse (SA) with activation checkpointing. SA empowers 10x and 16x longer sequences compared with dense for BERT base and large, respectively. The following figure shows the longest runnable sequence length for BERT base and large; the experiment is performed with batch size 1 on a single NVIDIA V100 GPU with 32 GB of memory.
|
||||
|
||||
{: .align-center}
|
||||
|
||||
* **up to 6.3x faster computation**
|
||||
We continued the pre-training experiment for different batch sizes and sequence lengths, using [BERT base/large](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert) and [Megatron GPT2](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM). In this experiment we let training continue for 100 iterations and recorded the average time of the last 30 iterations. SA reduces total computation compared with dense and improves training speed: the boost is higher with increased sequence length, and it is up to 6.3x faster for BERT base, 5.3x for BERT large, and 6.1x for GPT2. The following charts show these results.
|
||||
|
||||
{: .align-center}
|
||||
|
||||
{: .align-center}
|
||||
|
||||
{: .align-center}
|
||||
|
||||
* **higher accuracy**
|
||||
Related works along the line of sparse attention ([Sparse Transformer](https://arxiv.org/pdf/1904.10509.pdf), [Longformer](https://arxiv.org/pdf/2004.05150.pdf), [BigBird](https://arxiv.org/pdf/2007.14062.pdf)) have shown comparable or higher accuracy than full attention. Our experience is well aligned. In addition to lower memory overhead and faster computation, we also observe cases in production where SA reaches higher accuracy and faster convergence. The following chart illustrates the accuracy of training a production model based on BERT for long-document comprehension (2,048 sequence length). The experiment is performed in three settings: dense starting from scratch, SA starting from scratch, and SA continuing training from a checkpoint trained with dense at sequence length 512. We have observed that, for pre-training from scratch, SA converges faster and with higher accuracy compared with dense. Furthermore, continuing training with SA from a pre-trained checkpoint performs even better, with respect to both time and accuracy.
|
||||
|
||||
|
||||
{: .align-center}
|
||||
|
||||
* **flexibility to handle any block-sparse structure**
|
||||
The DeepSpeed Sparse Attention suite does not target any specific sparse structure but enables model scientists to explore any block-sparse structure with efficient system support. Currently, we have added popular sparse structures such as:
|
||||
* [Fixed](https://arxiv.org/pdf/1904.10509.pdf) (from OpenAI Sparse Transformer)
|
||||
* [BigBird](https://arxiv.org/pdf/2007.14062.pdf) (from Google)
|
||||
* BSLongformer (Block-Sparse implementation of [Longformer](https://arxiv.org/pdf/2004.05150.pdf) from AI2)
|
||||
We also define a template for a `variable` structure (top figure), which can be used to easily customize any block-sparse random/local/global attention pattern. In addition to this list, users can add any other sparsity structure as described in the [tutorial](https://github.com/microsoft/DeepSpeed-internal/tree/master/docs/_tutorials/sparse_transformer.md) section.
|
||||
|
||||
|
||||
* **comparison with state of the art, Longformer**
|
||||
We compared SA with Longformer, a state-of-the-art sparse structure and implementation. In our experiment, SA uses `Fixed` sparsity, and the two implementations have comparable accuracy. On system performance, SA outperforms Longformer both in training and inference:
|
||||
* 1.47x faster execution pre-training MLM on Wikitext103
|
||||
We ran an experiment following the [notebook](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) offered by Longformer. In this experiment, we pre-train an MLM model using the RoBERTa-base checkpoint. This is done on 8 V100-SXM2 GPUs. The following table shows the details of the result, in which using DeepSpeed Sparse Attention yields a 1.47x speedup.
|
||||
|
||||
|Model |Local Window Size |BPC |Train Step |Time Per Iteration |Time Improvement |Accuracy improvement |
|
||||
|-------------------|------------------|--------|------------|--------------------|------------------|----------------------|
|
||||
|RoBERTa Checkpoint | |2.5326 | | | | |
|
||||
|Longformer |512 |2.6535 |0 | |1.47 |1.01 |
|
||||
|Sparse Attention | |2.6321 | | | | |
|
||||
|Longformer | |1.6708 |3k |1.6280 | |1.01 |
|
||||
|Sparse Attention | |1.6613 | |1.1059 | | |
|
||||
|Longformer |64 |5.7840 |0 | |1.31 |1.46 |
|
||||
|Sparse Attention | |3.9737 | | | | |
|
||||
|Longformer | |2.0466 |3k |1.4855 | |1.09 |
|
||||
|Sparse Attention | |1.8693 | |1.1372 | | |
|
||||
|
||||
|
||||
* 3.13x faster execution inference on BERT-Base
|
||||
For the long-document comprehension application described above, we also checked the inference time for different window sizes, testing the BERT model with a sequence length of `2,048` and batch size `1`. In this experiment, we observed up to a `3.13x` speedup when replacing BERT attention with DeepSpeed Sparse Attention instead of Longformer attention. The following table shows the complete result.
|
||||
|
||||
|Local Window Size |Time Improvement|
|
||||
|--------------------|----------------|
|
||||
|512 |3.13 |
|
||||
|256 |2.29 |
|
||||
|128 |2.16 |
|
||||
|64 |1.5 |
|
||||
|32 |1.24 |
|
||||
|16 |1.23 |
|
88
docs/_tutorials/sparse_attention.md
Normal file
@ -0,0 +1,88 @@
|
||||
---
|
||||
title: "DeepSpeed Sparse Attention"
|
||||
---
|
||||
|
||||
In this tutorial we describe how to use DeepSpeed Sparse Attention and its building-block kernels, either through the DeepSpeed launcher or by integrating the individual kernels into your code.
|
||||
|
||||
**Note:** Currently DeepSpeed Sparse Attention can be used only on NVIDIA V100 GPUs with CUDA 10.1 or 10.2.
|
||||
{: .notice--warning}
|
||||
|
||||
## How to use
|
||||
DeepSpeed Sparse Attention can be used as a feature through DeepSpeed, or simply integrated with any Transformer model as a standalone self-attention module. Further, the building-block kernels, matrix multiplication and softmax, can be used separately. To use sparse attention alone, you can simply install DeepSpeed and import any of the following modules from it; for example:
|
||||
```python
|
||||
from deepspeed.ops.sparse_attention import SparseSelfAttention
|
||||
```
|
||||
In the following we describe the Sparse Attention modules; a short end-to-end usage sketch follows the list:
|
||||
* **MatMul**: This module handles block-sparse matrix-matrix multiplication. Currently it supports SDD, DSD, and DDS as described in [DeepSpeed Sparse Attention](https://github.com/microsoft/DeepSpeed-internal/tree/master/docs/_posts/2020-09-09-sparse-attention.md) section.
|
||||
* **Softmax**: This module applies block-sparse softmax. It handles both the forward and backward passes.
|
||||
* **SparseSelfAttention**: This module uses the MatMul and Softmax kernels and generates the context-layer output given Query, Key, and Value. It is a simplified version of the common operations in any self-attention layer. It can also apply:
|
||||
* `Relative position embedding`
|
||||
* `Attention mask`
|
||||
* `Key padding mask`
|
||||
on the intermediate attention scores. For more details about self-attention, please check [MultiHeadAttention](https://pytorch.org/docs/master/generated/torch.nn.MultiheadAttention.html#multiheadattention).
|
||||
* **BertSparseSelfAttention**: This module contains a simplified BertSelfAttention layer that can be used instead of the original dense BERT self-attention layer. Our implementation is based on [DeepSpeedExample](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373-#L434).
|
||||
* **SparseAttentionUtils**: This module provides a few utility functions for adapting a pre-trained model to sparse attention:
|
||||
* `replace_model_self_attention_with_sparse_self_attention`: If you have already loaded a model and want to replace its self-attention module with sparse self-attention, you can simply use this function to handle it for you. It currently handles BERT- and RoBERTa-based pre-trained models, but you can extend it based on your model type if it is different from these two. You also need to extend the position embedding to handle the new sequence length; this can be done using the `extend_position_embedding` function.
|
||||
* `update_tokenizer_model_max_length`: This function simply updates the maximum sequence length in your tokenizer with the new value.
|
||||
* `extend_position_embedding`: This function extends the position embedding based on the current values. For example, if you have a model with a maximum sequence length of 128 and extend it to a sequence length of 1k, it replicates the current embeddings 8 times to initialize the new embedding. Experimentally, we have seen that such initialization works much better than initializing from scratch and leads to faster convergence.
|
||||
* `pad_to_block_size`: This function pads the input tokens and attention mask along the sequence-length dimension to a multiple of the block size; this is a requirement for SA.
|
||||
* `unpad_sequence_output`: This function unpads the sequence output if the inputs of the model were padded.
|
||||
* **SparsityConfig**: This is an abstract class for sparsity structures. Any sparsity structure extends this class and writes its own `make_layout` function. DeepSpeed currently provides the following structures, which will be described in the next section:
|
||||
* `FixedSparsityConfig`
|
||||
* `BSLongformerSparsityConfig`
|
||||
* `BigBirdSparsityConfig`
|
||||
* `VariableSparsityConfig`
|
||||
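As a concrete end-to-end illustration of how these pieces fit together, here is a minimal sketch that builds a `FixedSparsityConfig` and runs `SparseSelfAttention` on random half-precision tensors. It requires a supported GPU/CUDA setup (see the note above), and the constructor arguments (`num_heads`, `block`, `sparsity_config`) and the `(batch, heads, seq, head_dim)` input layout reflect our reading of this release and may differ in later versions:

```python
import torch
from deepspeed.ops.sparse_attention import FixedSparsityConfig, SparseSelfAttention

batch, heads, seq_len, head_dim, block = 1, 4, 256, 64, 16   # seq_len must be a multiple of block

# "fixed" sparsity: 4 local blocks per window, 1 global block per window.
sparsity_config = FixedSparsityConfig(num_heads=heads,
                                      block=block,
                                      num_local_blocks=4,
                                      num_global_blocks=1)
sparse_attn = SparseSelfAttention(sparsity_config=sparsity_config).cuda()

q = torch.rand(batch, heads, seq_len, head_dim, dtype=torch.half, device='cuda')
k = torch.rand_like(q)
v = torch.rand_like(q)

context = sparse_attn(q, k, v)   # block-sparse scores -> block-sparse softmax -> context layer
print(context.shape)             # torch.Size([1, 4, 256, 64])
```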
|
||||
### BertSparseSelfAttention Example
|
||||
We have currently integrated Sparse Attention into our [bing_bert](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py) code, which can be used as an example for integration. In this example, we replace the BertSelfAttention module with BertSparseSelfAttention. Using the DeepSpeed launcher, you can enable sparse attention with the `deepspeed_sparse_attention` argument ([example](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/ds_sa_train_bert_bsz64k_seq128.sh)) and add your desired sparsity config to the [DeepSpeed config file](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_bsz64k_lamb_config_seq128.json). In this example, we have used the `fixed` sparsity mode. Further, you need to pad the sequence dimension of `input_ids` and `attention_mask` to be a multiple of the sparse block size. As mentioned above, DeepSpeed provides utility functions for padding and unpadding, and you can check our [example](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py) to see where and how to pad and unpad the inputs or outputs of the model; a conceptual sketch of the padding step follows below.
|
||||
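The `SparseAttentionUtils.pad_to_block_size` and `unpad_sequence_output` helpers handle this for you; conceptually, the padding step amounts to something like the following sketch (dummy tensors for illustration, not the utility's actual signature):

```python
import torch
import torch.nn.functional as F

block, pad_token_id = 16, 0
input_ids = torch.randint(1000, (2, 250))   # (batch, seq_len), seq_len not a multiple of block
attention_mask = torch.ones_like(input_ids)

seq_len = input_ids.size(1)
pad_len = (block - seq_len % block) % block

# Pad the sequence dimension up to a multiple of the sparse block size.
input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
print(input_ids.shape)                      # torch.Size([2, 256])

# ... run the model with sparse attention, then unpad its sequence output:
# sequence_output = sequence_output[:, :seq_len]
```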
|
||||
**Note:** Currently DeepSpeed Transformer Kernels do not support Sparse Attention. To use Sparse Attention, you need to disable Transformer Kernels!
|
||||
{: .notice--warning}
|
||||
|
||||
### Sparsity structures
|
||||
Below we describe the supported sparsity structures, their parameter sets, and the flexibility of adding an arbitrary sparsity pattern to the self-attention layer.
|
||||
* **SparsityConfig**:
|
||||
This module is the parent class for all sparsity structures and contains the features shared by all of them. It takes the following parameters:
|
||||
* `num_heads`: an integer determining number of attention heads of the layer.
|
||||
* `block`: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such square blocks, `Block X Block`.
|
||||
* `different_layout_per_head`: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
|
||||
|
||||
* **Fixed** (FixedSparsityConfig):
|
||||
This structure is based on [Generative Modeling with Sparse Transformers](https://arxiv.org/abs/1904.10509) from OpenAI, in which local and global attention is fixed by the given parameters:
|
||||
* `num_local_blocks`: an integer determining the number of blocks in a local attention window. As illustrated in the figure below (adapted from the original paper), tokens in a local window attend to all tokens local to them. In the case of an autoregressive model, as in the figure, tokens attend to tokens appearing before them in the local window. In the case of a masked model such as BERT, attention is bidirectional.
|
||||
* `num_global_blocks`: an integer determining how many consecutive blocks in a local window are used as the representative of the window for global attention; illustrated in the figure below as well.
|
||||
* `attention`: a string determining the attention type. Attention can be `unidirectional`, as in autoregressive models, in which tokens attend only to tokens that appear before them in the context; considering that, the upper triangular part of the attention matrix is empty, as in the figure above. Or it can be `bidirectional`, as in BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix mirrors the lower triangular part in the figure above.
|
||||
* `horizontal_global_attention`: a boolean determining whether blocks that are the global representatives of a local window also attend to all other blocks. This is valid only if the attention type is `bidirectional`. In terms of the attention matrix, global attention then includes not only the vertical blocks but also the horizontal blocks.
|
||||
* `num_different_global_patterns`: an integer determining the number of different global attention layouts. While global attention can be fixed by which block(s) are representative of a local window, since there are multiple heads, each head can use a different global representative. For example, with 4 blocks constructing a local window and a global attention size of a single block, we can have 4 different versions in which the first, second, third, or fourth block of each local window is the global representative of that window. This parameter determines how many such patterns we want. Of course, there is a limitation based on `num_local_blocks` and `num_global_blocks`. Further, if you set this to more than one, you need to set `different_layout_per_head` to `True`.
|
||||
|
||||

|
||||
|
||||
* **BSLongformer** (BSLongformerSparsityConfig):
|
||||
This structure is an edited version of [Longformer: The Long-Document Transformer](https://arxiv.org/pdf/2004.05150.pdf), in which, instead of single-token-wise sparsity, we offer sparsity over blocks of tokens. The parameters that define this pattern are:
|
||||
* `num_sliding_window_blocks`: an integer determining the number of blocks in the sliding local attention window.
|
||||
* `global_block_indices`: a list of integers determining which blocks are considered as global attention. The given indices determine the blocks that all other token blocks attend to, and these blocks attend to all other token blocks. Note that if the `global_block_end_indices` parameter is set, this parameter is used as the starting index of each global window.
|
||||
* `global_block_end_indices`: a list of integers determining the end indices of the global window blocks. By default this is not used. But if it is set, it must have the same size as the `global_block_indices` parameter, and, combining these two parameters, for each index `i`, blocks from `global_block_indices[i]` to `global_block_end_indices[i]` (exclusive) are considered as global attention blocks.
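For example, the following sketch (with purely illustrative values) uses a 3-block sliding window and marks blocks 0–1 and 16–17 as global regions via the start/end index pairing described above.

```python
from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig

# Illustrative values: with end indices set, blocks [0, 2) and [16, 18)
# are treated as global attention blocks.
bslongformer_config = BSLongformerSparsityConfig(num_heads=16,
                                                 block=16,
                                                 num_sliding_window_blocks=3,
                                                 global_block_indices=[0, 16],
                                                 global_block_end_indices=[2, 18])
```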
|
||||
|
||||
* **BigBird** (BigBirdSparsityConfig):
|
||||
This structure is based on [Big Bird: Transformers for Longer Sequences](https://arxiv.org/pdf/2007.14062.pdf). It combines the ideas of the `fixed` and `longformer` patterns with random attention. The following parameters define this structure:
|
||||
* `num_random_blocks`: an integer determining how many blocks in each row block are attended randomly.
|
||||
* `num_sliding_window_blocks`: an integer determining the number of blocks in the sliding local attention window.
|
||||
* `num_global_blocks`: an integer determining how many consecutive blocks, starting from index 0, are considered as global attention. Global block tokens will be attended by all other block tokens and will attend to all other block tokens as well.
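A corresponding sketch with purely illustrative values:

```python
from deepspeed.ops.sparse_attention import BigBirdSparsityConfig

# Illustrative values: 1 random block per row block, a 3-block sliding window,
# and 1 leading global block.
bigbird_config = BigBirdSparsityConfig(num_heads=16,
                                       block=16,
                                       num_random_blocks=1,
                                       num_sliding_window_blocks=3,
                                       num_global_blocks=1)
```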
|
||||
|
||||
* **Variable** (VariableSparsityConfig):
|
||||
This structure also combines the ideas of local, global, and random attention. Further, it has the flexibility of defining local windows of variable size. The following parameters define this structure:
|
||||
* `num_random_blocks`: an integer determining how many blocks in each row block are attended randomly.
|
||||
* `local_window_blocks`: a list of integers determining the number of blocks in each local attention window. The first number determines the number of blocks in the first local window, the second number the second window, and so on; the last number determines the number of blocks in all remaining local windows.
|
||||
* `global_block_indices`: a list of integers determining which blocks are considered as global attention. The given indices determine the blocks that all other token blocks attend to, and these blocks attend to all other token blocks. Note that if the `global_block_end_indices` parameter is set, this parameter is used as the starting index of each global window.
|
||||
* `global_block_end_indices`: a list of integers determining the end indices of the global window blocks. By default this is not used. But if it is set, it must have the same size as the `global_block_indices` parameter, and, combining these two parameters, for each index `i`, blocks from `global_block_indices[i]` to `global_block_end_indices[i]` (exclusive) are considered as global attention blocks.
|
||||
* `attention`: a string determining the attention type. Attention can be `unidirectional`, as in autoregressive models, where tokens attend only to the tokens appearing before them in the context; in that case, the upper triangular part of the attention matrix is empty, as in the figure above. Or it can be `bidirectional`, as in BERT, where tokens can attend to any other tokens before or after them; then the upper triangular part of the attention matrix is the mirror of the lower triangular part in that figure.
|
||||
* `horizontal_global_attention`: a boolean determining whether blocks that are the global representatives of a local window also attend to all other blocks. This is valid only if the attention type is `bidirectional`. In terms of the attention matrix, global attention then includes not only the vertical blocks but also the horizontal blocks.
|
||||
The figure below illustrates an example of `variable` sparsity, in which blue, orange, and green blocks represent local, global, and random attention blocks, respectively.
|
||||
|
||||

|
||||
|
||||
Further, we provide a `dense` pattern (`DenseSparsityConfig`) that can be used for testing purposes; it simply represents full attention.
|
||||
|
||||
|
||||
### How to expand block-based sparsity patterns
|
||||
Our building-block kernels, block-based `MatMul` & `Softmax`, can accept any block-based sparsity. This provides the flexibility to apply any block-based sparsity pattern to the attention score. To define and apply a new sparsity pattern, you can simply follow any of the above sparsity structures: add a new class that extends `SparsityConfig` and define its `make_layout` function based on how your sparsity is structured. You can add any extra parameters you may need, or just use the default parameters of the parent class.
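A minimal sketch of such an extension is shown below. It assumes `make_layout` should return a `[num_heads, num_blocks, num_blocks]` tensor of 0/1 block indicators and that the parent constructor takes `num_heads`, `block`, and `different_layout_per_head`; treat these as assumptions and check the built-in structures for the exact interface.

```python
import torch
from deepspeed.ops.sparse_attention import SparsityConfig


class BlockDiagonalSparsityConfig(SparsityConfig):
    """Hypothetical structure in which every block attends only to itself."""
    def __init__(self, num_heads, block=16, different_layout_per_head=False):
        super().__init__(num_heads, block, different_layout_per_head)

    def make_layout(self, seq_len):
        # Assumed contract: return a [num_heads, num_blocks, num_blocks] 0/1 tensor.
        if seq_len % self.block != 0:
            raise ValueError('sequence length must be a multiple of the block size')
        num_blocks = seq_len // self.block
        layout = torch.zeros((self.num_heads, num_blocks, num_blocks),
                             dtype=torch.int64)
        for i in range(num_blocks):
            layout[:, i, i] = 1  # each block attends only to itself
        return layout
```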
|
@ -19,6 +19,9 @@ training](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html).
|
||||
|
||||
To use the transformer kernel for training a model, you should first integrate DeepSpeed into your training script using the [Getting Started](/getting-started/) guide.
|
||||
|
||||
**Note:** Currently DeepSpeed Transformer Kernels do not support Sparse Attention. To use Sparse Attention, you need to disable Transformer Kernels!
|
||||
{: .notice--warning}
|
||||
|
||||
### **Integrate Transformer Kernel**
|
||||
|
||||
First, you need to integrate the transformer kernel into the top-level model. Here, we show an example of instantiating the transformer kernel using the Pre-LN BERT-Large configuration settings. This configuration has 24 layers with a hidden dimension of 1024, and uses a sequence length of 128 and a batch size of 64. To add all these layers, we copy the same layer specification `num_hidden_layer` times with different IDs inside a ModuleList.
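A sketch of that pattern is shown below; the configuration field names and the layer constructor signature are assumptions based on the Pre-LN BERT-Large description in the text and may differ across DeepSpeed versions, so consult the transformer kernel tutorial for the exact API.

```python
import copy
import torch.nn as nn
from deepspeed import DeepSpeedTransformerConfig, DeepSpeedTransformerLayer

# Pre-LN BERT-Large-style settings from the text: 24 layers, hidden size 1024,
# sequence length 128, batch size 64. Field names are assumptions and may vary
# between DeepSpeed versions.
transformer_config = DeepSpeedTransformerConfig(batch_size=64,
                                                max_seq_length=128,
                                                hidden_size=1024,
                                                heads=16,
                                                num_hidden_layers=24,
                                                fp16=True,
                                                pre_layer_norm=True)

# Copy the same layer specification num_hidden_layers times, each with its own ID.
layers = nn.ModuleList([
    copy.deepcopy(DeepSpeedTransformerLayer(i, transformer_config))
    for i in range(transformer_config.num_hidden_layers)
])
```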
|
||||
|
BIN
docs/assets/images/sa_backward_pass.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
docs/assets/images/sa_bert_base_time_result.png
Normal file
After Width: | Height: | Size: 64 KiB |
BIN
docs/assets/images/sa_bert_large_time_result.png
Normal file
After Width: | Height: | Size: 60 KiB |
BIN
docs/assets/images/sa_fixed_sparsity_structure.png
Normal file
After Width: | Height: | Size: 69 KiB |
BIN
docs/assets/images/sa_forward_pass.png
Normal file
After Width: | Height: | Size: 34 KiB |
BIN
docs/assets/images/sa_gpt2_time_result.png
Normal file
After Width: | Height: | Size: 61 KiB |
BIN
docs/assets/images/sa_long_document_comprehension_result.png
Normal file
After Width: | Height: | Size: 58 KiB |
BIN
docs/assets/images/sa_maximum_sequence_runnable_on_bert.png
Normal file
After Width: | Height: | Size: 33 KiB |
BIN
docs/assets/images/sa_variable_sparsity_structure.png
Normal file
After Width: | Height: | Size: 99 KiB |
BIN
docs/assets/images/variable_sparsity_pattern.png
Normal file
After Width: | Height: | Size: 42 KiB |
26
install.sh
@ -147,11 +147,6 @@ if [ "$no_clean" == "0" ]; then
|
||||
rm_if_exist third_party/apex/apex.egg-info
|
||||
fi
|
||||
|
||||
echo "Updating git hash/branch info"
|
||||
echo "git_hash = '$(git rev-parse --short HEAD)'" > deepspeed/git_version_info.py
|
||||
echo "git_branch = '$(git rev-parse --abbrev-ref HEAD)'" >> deepspeed/git_version_info.py
|
||||
cat deepspeed/git_version_info.py
|
||||
|
||||
if [ "$pip_sudo" == "1" ]; then
|
||||
PIP_SUDO="sudo -H"
|
||||
else
|
||||
@ -159,7 +154,7 @@ else
|
||||
fi
|
||||
|
||||
if [ "$pip_mirror" != "" ]; then
|
||||
PIP_INSTALL="pip install -v -i $pip_mirror"
|
||||
PIP_INSTALL="pip install --use-feature=2020-resolver -v -i $pip_mirror"
|
||||
else
|
||||
PIP_INSTALL="pip install -v"
|
||||
fi
|
||||
@ -169,10 +164,10 @@ if [ ! -f $hostfile ]; then
|
||||
local_only=1
|
||||
fi
|
||||
|
||||
if [ "$skip_requirements" == "0" ]; then
|
||||
# Ensure dependencies are installed locally
|
||||
$PIP_SUDO $PIP_INSTALL -r requirements.txt
|
||||
fi
|
||||
#if [ "$skip_requirements" == "0" ]; then
|
||||
# # Ensure dependencies are installed locally
|
||||
# $PIP_SUDO $PIP_INSTALL -r requirements.txt
|
||||
#fi
|
||||
|
||||
# Build wheels
|
||||
if [ "$third_party_install" == "1" ]; then
|
||||
@ -205,7 +200,8 @@ if [ "$local_only" == "1" ]; then
|
||||
echo "Installing deepspeed"
|
||||
$PIP_SUDO pip uninstall -y deepspeed
|
||||
$PIP_SUDO $PIP_INSTALL dist/deepspeed*.whl
|
||||
python basic_install_test.py
|
||||
# -I to exclude local directory files
|
||||
python -I basic_install_test.py
|
||||
if [ $? == 0 ]; then
|
||||
echo "Installation is successful"
|
||||
else
|
||||
@ -224,10 +220,10 @@ else
|
||||
tmp_wheel_path="/tmp/deepspeed_wheels"
|
||||
|
||||
pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; else mkdir -pv $tmp_wheel_path; fi"
|
||||
pdcp -w $hosts requirements.txt ${tmp_wheel_path}/
|
||||
if [ "$skip_requirements" == "0" ]; then
|
||||
pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
|
||||
fi
|
||||
#pdcp -w $hosts requirements/*.txt ${tmp_wheel_path}/
|
||||
#if [ "$skip_requirements" == "0" ]; then
|
||||
# pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL -r ${tmp_wheel_path}/requirements.txt"
|
||||
#fi
|
||||
if [ "$third_party_install" == "1" ]; then
|
||||
pdsh -w $hosts "$PIP_SUDO pip uninstall -y apex"
|
||||
pdcp -w $hosts third_party/apex/dist/apex*.whl $tmp_wheel_path/
|
||||
|
4
requirements/requirements-dev.txt
Normal file
@ -0,0 +1,4 @@
|
||||
pytest
|
||||
pytest-forked
|
||||
pre-commit
|
||||
clang-format
|
1
requirements/requirements-sparse-attn.txt
Normal file
@ -0,0 +1 @@
|
||||
triton>=0.2.2
|
@ -3,7 +3,3 @@ torchvision>=0.4.0
|
||||
tqdm
|
||||
psutil
|
||||
tensorboardX==1.8
|
||||
pytest
|
||||
pytest-forked
|
||||
pre-commit
|
||||
clang-format
|
265
setup.py
@ -10,8 +10,53 @@ The wheel will be located at: dist/*.whl
|
||||
|
||||
import os
|
||||
import torch
|
||||
import subprocess
|
||||
import warnings
|
||||
from setuptools import setup, find_packages
|
||||
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
||||
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension
|
||||
|
||||
VERSION = "0.3.0"
|
||||
|
||||
|
||||
def fetch_requirements(path):
|
||||
with open(path, 'r') as fd:
|
||||
return [r.strip() for r in fd.readlines()]
|
||||
|
||||
|
||||
install_requires = fetch_requirements('requirements/requirements.txt')
|
||||
dev_requires = fetch_requirements('requirements/requirements-dev.txt')
|
||||
sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
|
||||
|
||||
# Build environment variables for custom builds
|
||||
DS_BUILD_LAMB_MASK = 1
|
||||
DS_BUILD_TRANSFORMER_MASK = 10
|
||||
DS_BUILD_SPARSE_ATTN_MASK = 100
|
||||
|
||||
# Allow for build_cuda to turn on or off all ops
|
||||
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK
|
||||
DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS
|
||||
|
||||
# Set default of each op based on if build_cuda is set
|
||||
OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
|
||||
DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
|
||||
DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
|
||||
OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
|
||||
DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
|
||||
0)) * DS_BUILD_SPARSE_ATTN_MASK
|
||||
|
||||
# Final effective mask is the bitwise OR of each op
|
||||
BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN)
|
||||
|
||||
install_ops = []
|
||||
if BUILD_MASK & DS_BUILD_LAMB:
|
||||
install_ops.append('lamb')
|
||||
if BUILD_MASK & DS_BUILD_TRANSFORMER:
|
||||
install_ops.append('transformer')
|
||||
if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
|
||||
install_ops.append('sparse-attn')
|
||||
if len(install_ops) == 0:
|
||||
print("Building without any cuda/cpp extensions")
|
||||
print(f'BUILD_MASK={BUILD_MASK}, install_ops={install_ops}')
|
||||
|
||||
cmdclass = {}
|
||||
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
|
||||
@ -40,95 +85,163 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
|
||||
version_ge_1_5 = ['-DVERSION_GE_1_5']
|
||||
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
|
||||
|
||||
ext_modules = [
|
||||
CUDAExtension(
|
||||
name='deepspeed_lamb_cuda',
|
||||
sources=['csrc/lamb/fused_lamb_cuda.cpp',
|
||||
'csrc/lamb/fused_lamb_cuda_kernel.cu'],
|
||||
include_dirs=['csrc/includes'],
|
||||
extra_compile_args={
|
||||
'cxx': [
|
||||
'-O3',
|
||||
] + version_dependent_macros,
|
||||
'nvcc': ['-O3',
|
||||
'--use_fast_math'] + version_dependent_macros
|
||||
}),
|
||||
CUDAExtension(name='deepspeed_transformer_cuda',
|
||||
sources=[
|
||||
'csrc/transformer/ds_transformer_cuda.cpp',
|
||||
'csrc/transformer/cublas_wrappers.cu',
|
||||
'csrc/transformer/transform_kernels.cu',
|
||||
'csrc/transformer/gelu_kernels.cu',
|
||||
'csrc/transformer/dropout_kernels.cu',
|
||||
'csrc/transformer/normalize_kernels.cu',
|
||||
'csrc/transformer/softmax_kernels.cu',
|
||||
'csrc/transformer/general_kernels.cu'
|
||||
],
|
||||
include_dirs=['csrc/includes'],
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3',
|
||||
ext_modules = []
|
||||
|
||||
## Lamb ##
|
||||
if BUILD_MASK & DS_BUILD_LAMB:
|
||||
ext_modules.append(
|
||||
CUDAExtension(name='deepspeed.ops.lamb.fused_lamb_cuda',
|
||||
sources=[
|
||||
'csrc/lamb/fused_lamb_cuda.cpp',
|
||||
'csrc/lamb/fused_lamb_cuda_kernel.cu'
|
||||
],
|
||||
include_dirs=['csrc/includes'],
|
||||
extra_compile_args={
|
||||
'cxx': [
|
||||
'-O3',
|
||||
] + version_dependent_macros,
|
||||
'nvcc': ['-O3',
|
||||
'--use_fast_math'] + version_dependent_macros
|
||||
}))
|
||||
|
||||
## Transformer ##
|
||||
if BUILD_MASK & DS_BUILD_TRANSFORMER:
|
||||
ext_modules.append(
|
||||
CUDAExtension(name='deepspeed.ops.transformer.transformer_cuda',
|
||||
sources=[
|
||||
'csrc/transformer/ds_transformer_cuda.cpp',
|
||||
'csrc/transformer/cublas_wrappers.cu',
|
||||
'csrc/transformer/transform_kernels.cu',
|
||||
'csrc/transformer/gelu_kernels.cu',
|
||||
'csrc/transformer/dropout_kernels.cu',
|
||||
'csrc/transformer/normalize_kernels.cu',
|
||||
'csrc/transformer/softmax_kernels.cu',
|
||||
'csrc/transformer/general_kernels.cu'
|
||||
],
|
||||
include_dirs=['csrc/includes'],
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3',
|
||||
'-std=c++14',
|
||||
'-g',
|
||||
'-Wno-reorder'],
|
||||
'nvcc': [
|
||||
'-O3',
|
||||
'--use_fast_math',
|
||||
'-gencode',
|
||||
'arch=compute_61,code=compute_61',
|
||||
'-gencode',
|
||||
'arch=compute_70,code=compute_70',
|
||||
'-std=c++14',
|
||||
'-g',
|
||||
'-Wno-reorder'],
|
||||
'nvcc': [
|
||||
'-O3',
|
||||
'--use_fast_math',
|
||||
'-gencode',
|
||||
'arch=compute_61,code=compute_61',
|
||||
'-gencode',
|
||||
'arch=compute_70,code=compute_70',
|
||||
'-std=c++14',
|
||||
'-U__CUDA_NO_HALF_OPERATORS__',
|
||||
'-U__CUDA_NO_HALF_CONVERSIONS__',
|
||||
'-U__CUDA_NO_HALF2_OPERATORS__'
|
||||
]
|
||||
}),
|
||||
CUDAExtension(name='deepspeed_stochastic_transformer_cuda',
|
||||
sources=[
|
||||
'csrc/transformer/ds_transformer_cuda.cpp',
|
||||
'csrc/transformer/cublas_wrappers.cu',
|
||||
'csrc/transformer/transform_kernels.cu',
|
||||
'csrc/transformer/gelu_kernels.cu',
|
||||
'csrc/transformer/dropout_kernels.cu',
|
||||
'csrc/transformer/normalize_kernels.cu',
|
||||
'csrc/transformer/softmax_kernels.cu',
|
||||
'csrc/transformer/general_kernels.cu'
|
||||
],
|
||||
include_dirs=['csrc/includes'],
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3',
|
||||
'-U__CUDA_NO_HALF_OPERATORS__',
|
||||
'-U__CUDA_NO_HALF_CONVERSIONS__',
|
||||
'-U__CUDA_NO_HALF2_OPERATORS__'
|
||||
]
|
||||
}))
|
||||
ext_modules.append(
|
||||
CUDAExtension(name='deepspeed.ops.transformer.stochastic_transformer_cuda',
|
||||
sources=[
|
||||
'csrc/transformer/ds_transformer_cuda.cpp',
|
||||
'csrc/transformer/cublas_wrappers.cu',
|
||||
'csrc/transformer/transform_kernels.cu',
|
||||
'csrc/transformer/gelu_kernels.cu',
|
||||
'csrc/transformer/dropout_kernels.cu',
|
||||
'csrc/transformer/normalize_kernels.cu',
|
||||
'csrc/transformer/softmax_kernels.cu',
|
||||
'csrc/transformer/general_kernels.cu'
|
||||
],
|
||||
include_dirs=['csrc/includes'],
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3',
|
||||
'-std=c++14',
|
||||
'-g',
|
||||
'-Wno-reorder'],
|
||||
'nvcc': [
|
||||
'-O3',
|
||||
'--use_fast_math',
|
||||
'-gencode',
|
||||
'arch=compute_61,code=compute_61',
|
||||
'-gencode',
|
||||
'arch=compute_70,code=compute_70',
|
||||
'-std=c++14',
|
||||
'-g',
|
||||
'-Wno-reorder'],
|
||||
'nvcc': [
|
||||
'-O3',
|
||||
'--use_fast_math',
|
||||
'-gencode',
|
||||
'arch=compute_61,code=compute_61',
|
||||
'-gencode',
|
||||
'arch=compute_70,code=compute_70',
|
||||
'-std=c++14',
|
||||
'-U__CUDA_NO_HALF_OPERATORS__',
|
||||
'-U__CUDA_NO_HALF_CONVERSIONS__',
|
||||
'-U__CUDA_NO_HALF2_OPERATORS__',
|
||||
'-D__STOCHASTIC_MODE__'
|
||||
]
|
||||
}),
|
||||
]
|
||||
'-U__CUDA_NO_HALF_OPERATORS__',
|
||||
'-U__CUDA_NO_HALF_CONVERSIONS__',
|
||||
'-U__CUDA_NO_HALF2_OPERATORS__',
|
||||
'-D__STOCHASTIC_MODE__'
|
||||
]
|
||||
}))
|
||||
|
||||
|
||||
def command_exists(cmd):
|
||||
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
|
||||
return result.wait() == 0
|
||||
|
||||
|
||||
## Sparse transformer ##
|
||||
if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
|
||||
# Check to see if llvm and cmake are installed since they are dependencies
|
||||
required_commands = ['llc-9', 'cmake']
|
||||
|
||||
command_status = list(map(command_exists, required_commands))
|
||||
if not all(command_status):
|
||||
zipped_status = list(zip(required_commands, command_status))
|
||||
warnings.warn(
|
||||
f'Missing non-python requirements, please install the missing packages: {zipped_status}'
|
||||
)
|
||||
warnings.warn(
|
||||
'Skipping sparse attention installation due to missing required packages')
|
||||
elif TORCH_MAJOR == 1 and TORCH_MINOR >= 5:
|
||||
ext_modules.append(
|
||||
CppExtension(name='deepspeed.ops.sparse_attention.cpp_utils',
|
||||
sources=['csrc/sparse_attention/utils.cpp'],
|
||||
extra_compile_args={'cxx': ['-O2',
|
||||
'-fopenmp']}))
|
||||
# Add sparse attention requirements
|
||||
install_requires += sparse_attn_requires
|
||||
else:
|
||||
warnings.warn('Unable to meet requirements to install sparse attention')
|
||||
|
||||
# Add development requirements
|
||||
install_requires += dev_requires
|
||||
|
||||
# Write out version/git info
|
||||
git_hash_cmd = "git rev-parse --short HEAD"
|
||||
git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
|
||||
if command_exists('git'):
|
||||
result = subprocess.check_output(git_hash_cmd, shell=True)
|
||||
git_hash = result.decode('utf-8').strip()
|
||||
result = subprocess.check_output(git_branch_cmd, shell=True)
|
||||
git_branch = result.decode('utf-8').strip()
|
||||
else:
|
||||
git_hash = "unknown"
|
||||
git_branch = "unknown"
|
||||
print(f"version={VERSION}+{git_hash}, git_hash={git_hash}, git_branch={git_branch}")
|
||||
with open('deepspeed/git_version_info.py', 'w') as fd:
|
||||
fd.write(f"version='{VERSION}+{git_hash}'\n")
|
||||
fd.write(f"git_hash='{git_hash}'\n")
|
||||
fd.write(f"git_branch='{git_branch}'\n")
|
||||
|
||||
print(f'install_requires={install_requires}')
|
||||
|
||||
setup(name='deepspeed',
|
||||
version='0.2.0',
|
||||
version=f"{VERSION}+{git_hash}",
|
||||
description='DeepSpeed library',
|
||||
author='DeepSpeed Team',
|
||||
author_email='deepspeed@microsoft.com',
|
||||
url='http://aka.ms/deepspeed',
|
||||
install_requires=install_requires,
|
||||
packages=find_packages(exclude=["docker",
|
||||
"third_party",
|
||||
"csrc"]),
|
||||
package_data={'deepspeed.ops.sparse_attention.trsrc': ['*.tr']},
|
||||
scripts=['bin/deepspeed',
|
||||
'bin/deepspeed.pt',
|
||||
'bin/ds',
|
||||
'bin/ds_ssh'],
|
||||
classifiers=['Programming Language :: Python :: 3.6'],
|
||||
classifiers=[
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8'
|
||||
],
|
||||
license='MIT',
|
||||
ext_modules=ext_modules,
|
||||
cmdclass=cmdclass)
|
||||
|
342
tests/unit/skip_sparse_attention.py
Executable file
@ -0,0 +1,342 @@
|
||||
# DeepSpeed note, some parts of code taken & adapted from commit c368a9fd1b2c9dee4cc94de9a6bb0be3d447be41
|
||||
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_softmax.py
|
||||
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_matmul.py
|
||||
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/utils
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def test_sparse_attention_module_availability():
|
||||
try:
|
||||
from deepspeed.ops import sparse_attention
|
||||
except ImportError:
|
||||
print("Sparse Attention Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_matmul_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import MatMul
|
||||
except ImportError:
|
||||
print("Sparse MatMul Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_softmax_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import Softmax
|
||||
except ImportError:
|
||||
print("Sparse Softmax Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_sparsityconfig_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import SparsityConfig
|
||||
except ImportError:
|
||||
print("SparsityConfig Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_densesparsityconfig_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import DenseSparsityConfig
|
||||
except ImportError:
|
||||
print("DenseSparsityConfig Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_fixedsparsityconfig_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import FixedSparsityConfig
|
||||
except ImportError:
|
||||
print("FixedSparsityConfig Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_variablesparsityconfig_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import VariableSparsityConfig
|
||||
except ImportError:
|
||||
print("VariableSparsityConfig Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_bigbirdsparsityconfig_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import BigBirdSparsityConfig
|
||||
except ImportError:
|
||||
print("BigBirdSparsityConfig Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_bslongformersparsityconfig_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig
|
||||
except ImportError:
|
||||
print("BSLongformerSparsityConfig Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_sparseselfattention_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import SparseSelfAttention
|
||||
except ImportError:
|
||||
print("SparseSelfAttention Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_bertsparseselfattention_module_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import BertSparseSelfAttention
|
||||
except ImportError:
|
||||
print("BertSparseSelfAttention Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_sparseattentionutils_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import SparseAttentionUtils
|
||||
except ImportError:
|
||||
print("SparseAttentionUtils Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_cpp_utils_availability():
|
||||
try:
|
||||
from deepspeed.ops.sparse_attention import cpp_utils
|
||||
except ImportError:
|
||||
print("Sparse Attention cpp_utils Module is not installed!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def dense_to_sparse(w, mask, block):
|
||||
"""Converts dense matrix with explicit zeros to sparse matrix
|
||||
"""
|
||||
Z = w.size(0)
|
||||
ret = torch.empty((Z, mask.sum(), block, block), dtype=w.dtype, device=w.device)
|
||||
nnz = mask.nonzero()
|
||||
h, i, j = nnz[:, 0], nnz[:, 1], nnz[:, 2]
|
||||
for zz in range(Z):
|
||||
for idx, (hh, ii, jj) in enumerate(zip(h, i, j)):
|
||||
ret[zz, idx, :, :] = w[zz, hh, ii*block: (ii+1)*block, jj*block: (jj+1)*block]
|
||||
return ret
|
||||
|
||||
|
||||
def sparse_to_dense(w, mask, block, zero=0):
|
||||
"""Converts sparse matrix to dense matrix with explicit zeros
|
||||
"""
|
||||
maskedw = w.clone()
|
||||
for bz, wz in enumerate(range(0, w.size(0))):
|
||||
for bh, wh in enumerate(range(0, w.size(1))):
|
||||
for bi, wi in enumerate(range(0, w.size(2), block)):
|
||||
for bj, wj in enumerate(range(0, w.size(3), block)):
|
||||
if mask[bh, bi, bj] == 0:
|
||||
maskedw[wz, wh, wi:wi + block, wj:wj + block] = zero
|
||||
#maskedw[wz, wh, wi : wi+block, wj : wj+block] *= mask[bh, bi, bj]
|
||||
return maskedw
|
||||
|
||||
|
||||
def allclose(x, y):
|
||||
assert x.dtype == y.dtype
|
||||
rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
|
||||
return torch.allclose(x, y, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def make_layout(rho, shape):
|
||||
probs = torch.Tensor([rho, 1 - rho])
|
||||
generator = torch.distributions.categorical.Categorical(probs)
|
||||
layout = generator.sample(shape)
|
||||
return layout
|
||||
|
||||
|
||||
def run_softmax_reference(x, scale, dx, kp_mask, attn_mask, layout, block):
|
||||
x = sparse_to_dense(x, layout, block, zero=float('-inf'))
|
||||
x.retain_grad()
|
||||
if kp_mask is not None:
|
||||
bcattn_mask = attn_mask[None, None, :, :] + torch.zeros_like(x)
|
||||
x[bcattn_mask == 0] = float('-inf')
|
||||
y = torch.softmax(x * scale + kp_mask[:, None, None, :], -1)
|
||||
else:
|
||||
y = torch.softmax(x * scale, -1)
|
||||
y.backward(dx)
|
||||
dx = x.grad.clone()
|
||||
dx = dense_to_sparse(dx, layout, block)
|
||||
y = dense_to_sparse(y, layout, block)
|
||||
return y, dx
|
||||
|
||||
|
||||
def run_softmax_sparse(x, scale, dx, kp_mask, attn_mask, layout, block):
|
||||
from deepspeed.ops.sparse_attention import Softmax
|
||||
sparse_softmax = Softmax(layout, block, bench=False)
|
||||
dx = dense_to_sparse(dx, layout, block)
|
||||
x = dense_to_sparse(x, layout, block)
|
||||
x.retain_grad()
|
||||
y = sparse_softmax(x,
|
||||
scale=scale,
|
||||
key_padding_mask=kp_mask,
|
||||
key_padding_mask_mode='add',
|
||||
attn_mask=attn_mask,
|
||||
attn_mask_mode='mul')
|
||||
y.backward(dx)
|
||||
dx = x.grad.clone()
|
||||
x.grad.zero_()
|
||||
return x, dx
|
||||
|
||||
|
||||
def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layout=None):
|
||||
if layout is None:
|
||||
layout = make_layout(rho, (H, M // block, N // block))
|
||||
if dense_x:
|
||||
x = torch.rand((Z, H, M, N), dtype=dtype, requires_grad=True, device='cuda')
|
||||
else:
|
||||
x = torch.rand((Z,
|
||||
layout.sum(),
|
||||
block,
|
||||
block),
|
||||
dtype=dtype,
|
||||
requires_grad=True,
|
||||
device='cuda')
|
||||
dx = torch.rand_like(x)
|
||||
bool_attn_mask = torch.randint(low=0,
|
||||
high=2,
|
||||
size=(N,
|
||||
N),
|
||||
dtype=torch.bool,
|
||||
requires_grad=False,
|
||||
device='cuda')
|
||||
fp_attn_mask = bool_attn_mask.type(dtype)
|
||||
kp_mask = torch.randint(low=0,
|
||||
high=2,
|
||||
size=(Z,
|
||||
N),
|
||||
dtype=dtype,
|
||||
requires_grad=False,
|
||||
device='cuda')
|
||||
kp_mask[kp_mask == 1.] = float('-inf')
|
||||
return layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask
|
||||
|
||||
|
||||
def _skip_on_cuda_compatability():
|
||||
if torch.cuda.get_device_capability()[0] != 7:
|
||||
pytest.skip("needs compute capability 7; v100")
|
||||
cuda_major = int(torch.version.cuda.split('.')[0]) * 10
|
||||
cuda_minor = int(torch.version.cuda.split('.')[1])
|
||||
cuda_version = cuda_major + cuda_minor
|
||||
if cuda_version != 101 and cuda_version != 102:
|
||||
pytest.skip("requires cuda 10.1 or 10.2")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32])
|
||||
@pytest.mark.parametrize("width", [256, 576])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_softmax(block, width, dtype):
|
||||
_skip_on_cuda_compatability()
|
||||
Z = 2
|
||||
H = 4
|
||||
scale = 0.4
|
||||
rho = 0.4
|
||||
M = N = width
|
||||
layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask = init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, layout=None)
|
||||
ref_y, ref_dx = run_softmax_reference(x, scale, dx, kp_mask, bool_attn_mask, layout, block)
|
||||
st_y, st_dx = run_softmax_sparse(x, scale, dx, kp_mask, fp_attn_mask, layout, block)
|
||||
assert allclose(ref_y, st_y)
|
||||
assert allclose(ref_dx, st_dx)
|
||||
|
||||
|
||||
def run_matmul_reference(x, w, mode, trans_a, trans_b, layout, block, dy):
|
||||
x = sparse_to_dense(x, layout, block) if mode == 'dsd' else x
|
||||
w = sparse_to_dense(w, layout, block) if mode == 'dds' else w
|
||||
x.retain_grad()
|
||||
w.retain_grad()
|
||||
xx = x.transpose(2, 3) if trans_a else x
|
||||
ww = w.transpose(2, 3) if trans_b else w
|
||||
y = torch.matmul(xx, ww)
|
||||
y = sparse_to_dense(y, layout, block) if mode == 'sdd' else y
|
||||
y.backward(dy)
|
||||
dx = x.grad.clone()
|
||||
dw = w.grad.clone()
|
||||
x.grad.zero_()
|
||||
w.grad.zero_()
|
||||
y = dense_to_sparse(y, layout, block) if mode == 'sdd' else y
|
||||
dx = dense_to_sparse(dx, layout, block) if mode == 'dsd' else dx
|
||||
dw = dense_to_sparse(dw, layout, block) if mode == 'dds' else dw
|
||||
return y, dx, dw
|
||||
|
||||
|
||||
def run_matmul_sparse(x, w, mode, trans_a, trans_b, layout, block, dy):
|
||||
from deepspeed.ops.sparse_attention import MatMul
|
||||
x = dense_to_sparse(x, layout, block) if mode == 'dsd' else x
|
||||
w = dense_to_sparse(w, layout, block) if mode == 'dds' else w
|
||||
dy = dense_to_sparse(dy, layout, block) if mode == 'sdd' else dy
|
||||
op = MatMul(layout, block, mode, trans_a=trans_a, trans_b=trans_b)
|
||||
x.retain_grad()
|
||||
w.retain_grad()
|
||||
y = op(x, w)
|
||||
y.backward(dy)
|
||||
dx = x.grad.clone()
|
||||
dw = w.grad.clone()
|
||||
x.grad.zero_()
|
||||
return y, dx, dw
|
||||
|
||||
|
||||
def init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout):
|
||||
torch.manual_seed(1)
|
||||
AS0 = K if trans_a else M
|
||||
AS1 = M if trans_a else K
|
||||
BS0 = N if trans_b else K
|
||||
BS1 = K if trans_b else N
|
||||
shape = {'sdd': (M, N), 'dsd': (AS0, AS1), 'dds': (BS0, BS1)}[mode]
|
||||
x = torch.rand((Z, H, AS0, AS1), dtype=dtype, requires_grad=True, device='cuda')
|
||||
w = torch.rand((Z, H, BS0, BS1), dtype=dtype, requires_grad=True, device='cuda')
|
||||
dy = torch.rand((Z, H, M, N), dtype=dtype, device='cuda')
|
||||
if layout is None:
|
||||
layout = make_layout(rho, (H, shape[0] // block, shape[1] // block))
|
||||
else:
|
||||
assert list(layout.shape) == [H, shape[0] // block, shape[1] // block]
|
||||
x.retain_grad()
|
||||
w.retain_grad()
|
||||
return x, w, dy, shape, layout
|
||||
|
||||
testdata = [
|
||||
(16, dtype, mode, trans_a, trans_b)\
|
||||
for dtype in [torch.float16, torch.float32]\
|
||||
for mode in ['sdd', 'dsd', 'dds']\
|
||||
for trans_a in [False, True]\
|
||||
for trans_b in [False, True]\
|
||||
] + [
|
||||
(block, torch.float16, mode, False, False)\
|
||||
for block in [16, 32, 64]\
|
||||
for mode in ['sdd', 'dsd', 'dds']\
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block, dtype, mode, trans_a, trans_b", testdata)
|
||||
def test_matmul(block, dtype, mode, trans_a, trans_b):
|
||||
_skip_on_cuda_compatability()
|
||||
Z = 3
|
||||
H = 2
|
||||
M = 128
|
||||
N = 256
|
||||
K = 192
|
||||
rho = 0.5
|
||||
x, w, dy, shape, layout = init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout=None)
|
||||
ref_y, ref_dx, ref_dw = run_matmul_reference(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)
|
||||
st_y, st_dx, st_dw = run_matmul_sparse(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)
|
||||
assert allclose(ref_y, st_y)
|
||||
assert allclose(ref_dx, st_dx)
|
||||
assert allclose(ref_dw, st_dw)
|
@ -1,10 +1,10 @@
|
||||
import torch
|
||||
import deepspeed
|
||||
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
|
||||
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
|
||||
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
|
||||
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
|
||||
|
||||
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
|
||||
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
|
||||
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
|
||||
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
|
||||
|
||||
import argparse
|
||||
import pytest
|
||||
|
@ -9,7 +9,7 @@ import torch.distributed as dist
|
||||
|
||||
# A test on its own
|
||||
import deepspeed
|
||||
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
|
||||
from deepspeed.runtime.config import DeepSpeedConfig
|
||||
|
||||
|
||||
def test_cuda():
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import random
|
||||
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
|
||||
from deepspeed.runtime.csr_tensor import CSRTensor
|
||||
|
||||
|
||||
def test_csr_addition_self():
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import os
|
||||
import json
|
||||
from deepspeed.pt import deepspeed_config as ds_config
|
||||
from deepspeed.runtime import config as ds_config
|
||||
|
||||
|
||||
def test_only_required_fields(tmpdir):
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from deepspeed.pt import deepspeed_run as dsrun
|
||||
from deepspeed.launcher import runner as dsrun
|
||||
|
||||
|
||||
def test_parser_mutual_exclusive():
|
||||
|