[inference] test suite for ds-kernels (bert, roberta, gpt2, gpt-neo, gpt-j) (#1992)

Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Jeff Rasley
2022-06-15 14:21:19 -07:00
committed by GitHub
parent e6f444aee2
commit b666d5cd73
18 changed files with 528 additions and 178 deletions
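
For context, the new tests exercise DeepSpeed's kernel-injection path on Hugging Face pipelines. A minimal standalone sketch of that pattern (not part of the diff; model, prompt, and device are illustrative) mirrors the init_inference call used in tests/unit/test_inference.py below:

import torch
import deepspeed
from transformers import pipeline

# Illustrative sketch: build a HF pipeline, then swap DeepSpeed's fused
# inference kernels into the model via init_inference().
pipe = pipeline("text-generation", model="distilgpt2", device=0, framework="pt")
pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=1,                        # the unit tests run single-GPU
    dtype=torch.half,
    replace_method="auto",
    replace_with_kernel_inject=True,  # enables the ds-kernels under test
)
print(pipe("DeepSpeed is the greatest", do_sample=False))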


@@ -63,5 +63,5 @@ jobs:
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 -m 'not sequential' unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -m 'sequential' unit/

.github/workflows/nv-inference.yml vendored Normal file

@@ -0,0 +1,63 @@
name: nv-inference
on:
push:
branches:
- 'master'
- 'staging**'
paths-ignore:
- 'docs/**'
pull_request:
paths-ignore:
- 'docs/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
unit-tests:
runs-on: [self-hosted, nvidia, cu111, v100]
steps:
- uses: actions/checkout@v2
- name: environment
run: |
nvidia-smi
which python
python --version
which nvcc
nvcc --version
pip install --upgrade pip
pip uninstall --yes torch torchvision
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
# git checkout 1cc453d33
git rev-parse --short HEAD
pip uninstall --yes transformers
pip install .
- name: Python environment
run: |
pip list
- name: Install deepspeed
run: |
pip uninstall --yes deepspeed
pip install .[dev,1bit,autotuning,sparse_attn,inf]
ds_report
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'inference' unit/

.github/workflows/nv-nightly.yml vendored Normal file

@@ -0,0 +1,52 @@
name: nv-nightly
on:
schedule:
- cron: "0 0 * * *"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
unit-tests:
runs-on: [self-hosted, nvidia, cu111, v100]
steps:
- uses: actions/checkout@v2
- name: environment
run: |
nvidia-smi
which python
python --version
which nvcc
nvcc --version
pip install --upgrade pip
pip uninstall --yes torch torchvision
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
# git checkout 1cc453d33
git rev-parse --short HEAD
pip uninstall --yes transformers
pip install .
- name: Install deepspeed
run: |
pip uninstall --yes deepspeed
pip install .[dev,1bit,autotuning,sparse_attn]
ds_report
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'nightly' unit/


@@ -60,5 +60,5 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/


@@ -53,5 +53,5 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/


@@ -60,5 +60,5 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/


@@ -57,6 +57,8 @@ jobs:
pip install .[testing]
# find reqs used in ds integration tests
find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec grep -v 'torch' {} \; | xargs -I {} pip install --upgrade {}
# force datasets version due to issues
pip install datasets==2.2.2
# force protobuf version due to issues
pip install "protobuf<4.21.0"
pip list

bin/dsr Symbolic link

@@ -0,0 +1 @@
ds_report


@@ -174,7 +174,8 @@ __global__ void fused_bias_residual(float* input,
float* attnbias,
int total_count,
int intermediate_size,
int mp_size)
int mp_size,
bool preln)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
@@ -189,12 +190,17 @@ __global__ void fused_bias_residual(float* input,
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
if (preln) {
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
} else {
data.x = data.x + out.x + bias_data.x;
data.y = data.y + out.y + bias_data.y;
data.z = data.z + out.z + bias_data.z;
data.w = data.w + out.w + bias_data.w;
}
output_cast[offset] = data;
}
}
@@ -206,7 +212,8 @@ __global__ void fused_bias_residual(__half* input,
__half* attn_bias,
int total_count,
int intermediate_size,
int mp_size)
int mp_size,
bool preln)
{
#ifdef HALF_PRECISION_AVAILABLE
@@ -248,15 +255,21 @@ __global__ void fused_bias_residual(__half* input,
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x =
(high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y =
(high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
if (preln) {
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x = (high_data.x + high_res.x) * mp_size +
(high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y = (high_data.y + high_res.y) * mp_size +
(high_out.y + (high_bias.y + attn_high_bias.y));
} else {
low_data.x = (low_data.x + low_out.x + low_bias.x);
low_data.y = (low_data.y + low_out.y + low_bias.y);
high_data.x = (high_data.x + high_out.x + high_bias.x);
high_data.y = (high_data.y + high_out.y + high_bias.y);
}
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
@@ -274,6 +287,7 @@ void launch_bias_residual(T* input,
int batch,
int hidden_dim,
int mp_size,
bool preln,
cudaStream_t stream)
{
int total_count = batch * hidden_dim / 4;
@@ -281,20 +295,13 @@ void launch_bias_residual(T* input,
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size, preln);
}
template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, cudaStream_t);
template void launch_bias_residual<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
cudaStream_t);
template void launch_bias_residual<
float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t);
template void launch_bias_residual<
__half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t);
__global__ void gptj_residual_add(float* input,
float* output,


@@ -787,17 +787,17 @@ at::Tensor ds_vector_matmul_int8(at::Tensor& input,
}
template <typename T>
void mlp_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
at::Tensor mlp_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
auto inp_norm = at::empty_like(input);
@@ -840,18 +840,19 @@ void mlp_unfused_cublas(at::Tensor& output,
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return inp_norm;
}
template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
@@ -863,19 +864,19 @@ at::Tensor ds_mlp_gemm(at::Tensor& input,
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
input_bias,
weight,
bias,
gamma,
beta,
epsilon,
preLayerNorm,
mlp_after_attn);
auto res_add = mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
input_bias,
weight,
bias,
gamma,
beta,
epsilon,
preLayerNorm,
mlp_after_attn);
return output;
return {output, res_add};
}
template <typename T>
@@ -1001,7 +1002,8 @@ void residual_add_bias(at::Tensor& output,
at::Tensor& attention_b,
int mp_size,
bool mlp_after_attn,
bool add_bias)
bool add_bias,
bool preln)
{
int bsz = input.size(0) * input.size(1);
int hidden_size = input.size(2);
@@ -1017,6 +1019,7 @@ void residual_add_bias(at::Tensor& output,
bsz,
hidden_size,
mp_size,
preln,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<float>((float*)input.data_ptr(),
@@ -1037,6 +1040,7 @@ void residual_add_bias(at::Tensor& output,
bsz,
hidden_size,
mp_size,
preln,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<__half>((__half*)input.data_ptr(),


@@ -58,6 +58,7 @@ void launch_bias_residual(T* input,
int batch,
int hidden_dim,
int mp_size,
bool preln,
cudaStream_t stream);
template <typename T>


@@ -3,21 +3,22 @@ Copyright 2021 The Microsoft DeepSpeed Team
'''
import torch
import os
from torch.nn.modules import Module
import deepspeed.comm as dist
import deepspeed.utils.groups as groups
from torch.nn.modules import Module
from packaging import version as pkg_version
from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject.replace_module import replace_transformer_layer
from ..utils import logger
from ..comm.comm import init_distributed
from ..pipe import PipelineModule
from ..moe.utils import has_moe_layers
from ..moe.layer import MoE
import deepspeed.comm as dist
import deepspeed.utils.groups as groups
DS_INFERENCE_ENABLED = False
@@ -88,9 +89,13 @@ class InferenceEngine(Module):
self.ep_group = ep_group
self.expert_mp_group = expert_mp_group
self.enable_cuda_graph = enable_cuda_graph
self.cuda_grah_created = False
self.cuda_graph_created = False
self._init_quantization_setting(quantization_setting)
if enable_cuda_graph:
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"
if self.checkpoint:
self._load_checkpoint(self.checkpoint)
@@ -372,7 +377,7 @@ class InferenceEngine(Module):
with torch.cuda.graph(self._cuda_graphs):
self.static_output = self.module(*self.static_inputs, **self.static_kwargs)
self.cuda_grah_created = True
self.cuda_graph_created = True
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
@@ -409,7 +414,7 @@ class InferenceEngine(Module):
outputs = self.model_orig_fwd(*inputs, **kwargs)
else:
if self.enable_cuda_graph:
if self.cuda_grah_created:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
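
As background for the cuda_graph_created rename and the new torch >= 1.10 assertion above, the capture-and-replay pattern the engine relies on looks roughly like the following standalone PyTorch sketch (illustrative module and shapes, not DeepSpeed code):

import torch

# Standalone CUDA Graph capture/replay sketch (requires torch >= 1.10 and a GPU).
model = torch.nn.Linear(8, 8).cuda().eval()
static_input = torch.randn(4, 8, device="cuda")

# Warm up on a side stream before capture, as the CUDA Graphs API expects.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)  # capture one forward pass

# Replay: copy fresh data into the captured input buffer, then replay;
# static_output is updated in place.
static_input.copy_(torch.randn(4, 8, device="cuda"))
graph.replay()
print(static_output)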


@@ -351,8 +351,11 @@ def replace_transformer_layer(orig_layer_impl,
# linear layer is created with [input, output] shape
# transpose it here to reduce inference cost!
def transpose(data):
# temp move to cpu to avoid requiring extra GPU memory during the reshape
data = data.to('cpu')
data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
data = data.reshape(data.shape[-1], data.shape[-2])
data = data.to(torch.cuda.current_device())  # assign the result; .to() is not in-place
return data
if attn_linear_layer:
@@ -460,7 +463,7 @@ def replace_transformer_layer(orig_layer_impl,
new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
else:
transformer_config = deepspeed.DeepSpeedTransformerConfig(
batch_size=micro_batch_size,
batch_size=micro_batch_size if micro_batch_size > 0 else 1,
hidden_size=config.hidden_size,
heads=config.num_attention_heads,
attn_dropout_ratio=config.attention_probs_dropout_prob,


@@ -471,7 +471,7 @@ class DeepSpeedMLPFunction(Function):
config.pre_layer_norm,
False)
else:
intermediate = mlp_gemm_func(input,
intermediate, residual_add = mlp_gemm_func(input,
residual,
bias,
inter_w,
@@ -482,14 +482,16 @@ class DeepSpeedMLPFunction(Function):
config.pre_layer_norm,
config.mlp_after_attn)
output = vector_matmul_func(intermediate, output_w, False)
inference_cuda_module.residual_add(output,
residual,
input,
output_b,
bias if bias is not None else output_b,
config.mp_size,
config.mlp_after_attn,
bias is not None)
inference_cuda_module.residual_add(
output,
residual if config.pre_layer_norm else residual_add,
input,
output_b,
bias if bias is not None else output_b,
config.mp_size,
config.mlp_after_attn,
bias is not None,
config.pre_layer_norm)
if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
dist.all_reduce(output, group=mp_group)
return output
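
For reference, the two paths selected by the new pre_layer_norm flag (implemented in the fused_bias_residual kernel changes earlier in this diff) can be written out in plain PyTorch roughly as below; this is a conceptual sketch in which the tensor roles are inferred from the kernel and the names are illustrative:

import torch

def residual_add_reference(mlp_out: torch.Tensor,
                           residual: torch.Tensor,
                           attn_out: torch.Tensor,
                           mlp_bias: torch.Tensor,
                           attn_bias: torch.Tensor,
                           mp_size: int,
                           preln: bool) -> torch.Tensor:
    # Conceptual counterpart of the updated fused_bias_residual kernel.
    if preln:
        # pre-LN: the residual branch is scaled by 1/mp_size, then the MLP
        # output and both biases are added back in.
        return (residual + attn_out) / mp_size + (mlp_out + mlp_bias + attn_bias)
    # post-LN: plain residual + bias add; the attention bias is not re-added.
    return residual + mlp_out + mlp_bias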


@@ -0,0 +1,2 @@
lm-eval>=0.2.0
transformers


@@ -61,7 +61,8 @@ extras_require = {
'dev': fetch_requirements('requirements/requirements-dev.txt'),
'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'),
'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'),
'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt')
'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
'inf': fetch_requirements('requirements/requirements-inf.txt')
}
# Add specific cupy version to both onebit extension variants
@@ -291,6 +292,7 @@ setup(name='deepspeed',
'bin/ds',
'bin/ds_ssh',
'bin/ds_report',
'bin/dsr',
'bin/ds_elastic'
],
classifiers=[

tests/pytest.ini Normal file

@@ -0,0 +1,6 @@
[pytest]
addopts = -m "not sequential and not nightly and not inference"
markers =
sequential:Tests that need to be run sequentially
inference:Inference model tests
nightly:Tests that should be run nightly
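
With this addopts default, a bare pytest run deselects sequential, nightly, and inference tests; the new workflows opt back in by passing an explicit -m expression (e.g. -m 'inference'), which takes precedence over the one from addopts. A minimal illustration of how tests pick up these markers (test names here are hypothetical):

import pytest

@pytest.mark.inference   # deselected by default, collected by nv-inference via -m 'inference'
def test_inference_example():
    assert True

@pytest.mark.nightly     # collected only by nv-nightly via -m 'nightly'
def test_nightly_example():
    assert True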


@@ -1,123 +1,323 @@
import os
import time
import torch
import pytest
import itertools
import deepspeed
from deepspeed.git_version_info import torch_info
from collections import defaultdict
from transformers import pipeline
from .common import distributed_test
from packaging import version as pkg_version
from deepspeed.ops.op_builder import OpBuilder
pytest.task_query_dict = {
"fill-mask":
defaultdict(
lambda: "Hello I'm a [MASK] model.",
{"roberta-base": "Hello I'm a <mask> model."},
),
"question-answering":
defaultdict(lambda: {
"question": "What is the greatest?",
"context": "DeepSpeed is the greatest",
}),
"text-classification":
defaultdict(lambda: "DeepSpeed is the greatest"),
"token-classification":
defaultdict(lambda: "My name is jean-baptiste and I live in montreal."),
"text-generation":
defaultdict(lambda: "DeepSpeed is the greatest"),
}
pytest.task_model_dict = {
"fill-mask": {
"bert": "bert-base-cased",
"roberta": "roberta-base"
},
"question-answering": {
"bert": "deepset/minilm-uncased-squad2",
"roberta": "deepset/roberta-base-squad2",
},
"text-classification": {
"bert": "cross-encoder/ms-marco-MiniLM-L-12-v2",
"roberta": "j-hartmann/emotion-english-distilroberta-base",
},
"token-classification": {
"bert": "dslim/bert-base-NER",
"roberta": "Jean-Baptiste/roberta-large-ner-english",
},
"text-generation": {
"gpt2": "distilgpt2",
"gpt_neo": "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
"gptj": "EleutherAI/gpt-j-6B",
},
try:
import lm_eval
import lm_eval.models
import lm_eval.tasks
from lm_eval.evaluator import evaluate
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import HfApi
except ImportError:
pytest.skip("please install w. [inf] extra to run this test",
allow_module_level=True)
rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
pytest.skip("skip inference tests on rocm for now", allow_module_level=True)
_bert_models = [
"bert-base-cased",
"bert-base-uncased",
"bert-large-cased",
"bert-large-uncased",
"bert-base-multilingual-cased",
"bert-base-multilingual-uncased",
"deepset/minilm-uncased-squad2",
"cross-encoder/ms-marco-MiniLM-L-12-v2",
"dslim/bert-base-NER",
"bert-large-uncased-whole-word-masking-finetuned-squad",
"distilbert-base-cased-distilled-squad",
]
_roberta_models = [
"roberta-large",
"roberta-base",
"deepset/roberta-base-squad2",
"j-hartmann/emotion-english-distilroberta-base",
"Jean-Baptiste/roberta-large-ner-english",
]
_gpt_models = [
"gpt2",
"distilgpt2",
"Norod78/hebrew-bad_wiki-gpt_neo-tiny",
"EleutherAI/gpt-j-6B",
]
_all_models = HfApi().list_models()
test_models = set(_bert_models + _roberta_models + _gpt_models)
test_tasks = [
"fill-mask",
"question-answering",
"text-classification",
"token-classification",
"text-generation",
]
pytest.all_models = {
task: [m.modelId for m in _all_models if m.pipeline_tag == task]
for task in test_tasks
}
_model_w_tasks = itertools.product(*[test_models, test_tasks])
def _valid_model_task(model_task):
m, t = model_task
return m in pytest.all_models[t]
pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks))
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks]
"""
These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph
"""
@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names)
def model_w_task(request):
return request.param
@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
def dtype(request):
return request.param
@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
def enable_cuda_graph(request):
return request.param
"""
This fixture will validate the configuration
"""
@pytest.fixture()
def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
model, task = model_w_task
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
msg = "DS inference injection doesn't work well on older torch versions"
elif model not in pytest.all_models[task]:
msg = f"Not a valid model / task combination: {model} / {task}"
elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
msg = "CUDA not detected, cannot use CUDA Graph"
elif enable_cuda_graph and pkg_version.parse(
torch.__version__) < pkg_version.parse("1.10"):
msg = "CUDA Graph is only available in torch versions >= 1.10"
elif ("gpt-j-6B" in model) and (dtype == torch.float):
msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
else:
msg = ""
return msg
"""
These fixtures can be used to customize the query, inference args, and assert
statement for each combination of model / task
"""
@pytest.fixture
def model(task, model_family):
if model_family not in pytest.task_model_dict[task]:
pytest.skip(f"No models in family {model_family} for task {task}")
return pytest.task_model_dict[task][model_family]
def query(model_w_task):
model, task = model_w_task
if task == "fill-mask":
if "roberta" in model:
return "Hello I'm a <mask> model."
else:
return "Hello I'm a [MASK] model."
elif task == "question-answering":
return {
"question": "What's my name?",
"context": "My name is Clara and I live in Berkeley",
}
elif task == "text-classification":
return "DeepSpeed is the greatest"
elif task == "token-classification":
return "My name is jean-baptiste and I live in montreal."
elif task == "text-generation":
return "DeepSpeed is the greatest"
else:
raise NotImplementedError(f'query for task "{task}" is not implemented')
@pytest.fixture
def query(task, model):
return pytest.task_query_dict[task][model]
def inf_kwargs(model_w_task):
model, task = model_w_task
if task == "text-generation":
return {"do_sample": False}
else:
return {}
@pytest.fixture
def assert_fn(model_w_task):
model, task = model_w_task
if task == "fill-mask":
return lambda x, y: set(res["token_str"] for res in x) == set(
res["token_str"] for res in y
)
elif task == "question-answering":
return lambda x, y: x["answer"] == y["answer"]
elif task == "text-classification":
return lambda x, y: set(res["label"] for res in x) == set(
res["label"] for res in y
)
elif task == "token-classification":
return lambda x, y: set(ent["word"] for ent in x) == set(
ent["word"] for ent in y
)
elif task == "text-generation":
return lambda x, y: set(res["generated_text"] for res in x) == set(
res["generated_text"] for res in y
)
else:
raise NotImplementedError(f'assert_fn for task "{task}" is not implemented')
"""
Tests
"""
@pytest.mark.inference
def test_model_task(
model_w_task,
dtype,
enable_cuda_graph,
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
model, task = model_w_task
@distributed_test(world_size=[1])
def _go():
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if "gpt-j-6B" in model and dtype == torch.half:
_model = AutoModelForCausalLM.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(model)
_model.half()
pipe = pipeline(
task,
model=_model,
tokenizer=tokenizer,
device=local_rank,
framework="pt",
)
else:
pipe = pipeline(task, model=model, device=local_rank, framework="pt")
if dtype == torch.half:
pipe.model.half()
# Warm-up queries for perf measurement
for i in range(10):
_ = pipe(query, **inf_kwargs)
torch.cuda.synchronize()
start = time.time()
bs_output = pipe(query, **inf_kwargs)
torch.cuda.synchronize()
bs_time = time.time() - start
pipe.model = deepspeed.init_inference(
pipe.model,
mp_size=1,
dtype=dtype,
replace_method="auto",
replace_with_kernel_inject=True,
enable_cuda_graph=enable_cuda_graph,
)
# Warm-up queries for perf measurement
for i in range(10):
_ = pipe(query, **inf_kwargs)
torch.cuda.synchronize()
start = time.time()
ds_output = pipe(query, **inf_kwargs)
torch.cuda.synchronize()
ds_time = time.time() - start
if task == "text-generation":
bs_output = pipe(query, **inf_kwargs)
# These performance tests are only measuring the time for a single
# inference request, we just want to check that performance isn't terrible
assert ds_time <= (bs_time * 1.1)
assert assert_fn(bs_output, ds_output)
_go()
@pytest.mark.nightly
@pytest.mark.parametrize(
"task",
"model_family, model_name",
(
"fill-mask",
"question-answering",
"text-classification",
"token-classification",
"text-generation",
["gpt2",
"EleutherAI/gpt-neo-2.7B"],
["gpt2",
"EleutherAI/gpt-j-6B"],
["gpt2",
"gpt2-xl"],
),
)
@pytest.mark.parametrize("model_family", ("bert", "roberta", "gpt2", "gpt_neo"))
def test_model_task_inject(task, model, query, dtype=torch.float):
if pkg_version.parse(torch.__version__) <= pkg_version.parse('1.2'):
pytest.skip("DS inference injection doesn't work well on older torch versions")
@pytest.mark.parametrize("task", ["lambada"])
def test_lm_correctness(model_family, model_name, task):
@distributed_test(world_size=[1])
def _go():
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
generator = pipeline(task, model=model, device=local_rank)
local_rank = os.getenv("LOCAL_RANK", "0")
device = torch.device(f"cuda:{local_rank}")
dtype = torch.float
task_dict = lm_eval.tasks.get_task_dict([task])
generator.model = deepspeed.init_inference(
generator.model,
mp_size=world_size,
if 'gpt-j-6B' in model_name:
dtype = torch.half
lm = lm_eval.models.get_model(model_family).create_from_arg_string(
f"pretrained={model_name}",
{"device": "cpu"})
setattr(lm, model_family, getattr(lm, model_family).half().to(device))
lm._device = device
else:
lm = lm_eval.models.get_model(model_family).create_from_arg_string(
f"pretrained={model_name}",
{"device": f"cuda:{local_rank}"})
torch.cuda.synchronize()
start = time.time()
bs_output = evaluate(lm=lm, task_dict=task_dict)
torch.cuda.synchronize()
bs_time = time.time() - start
ds_model = deepspeed.init_inference(
getattr(lm,
model_family),
mp_size=1,
dtype=dtype,
replace_method="auto",
replace_with_kernel_inject=True,
enable_cuda_graph=False,
)
setattr(lm, model_family, ds_model)
torch.cuda.synchronize()
start = time.time()
ds_output = evaluate(lm=lm, task_dict=task_dict)
torch.cuda.synchronize()
ds_time = time.time() - start
response = generator(query)
_go()
@pytest.mark.parametrize("dtype", [(torch.float), (torch.half)])
def test_gpt2_inject(dtype):
if pkg_version.parse(torch.__version__) <= pkg_version.parse('1.2'):
pytest.skip("DS inference injection doesn't work well on older torch versions")
@distributed_test(world_size=[1])
def _go():
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
generator = pipeline("text-generation", model="gpt2", device=local_rank)
generator.model = deepspeed.init_inference(
generator.model,
mp_size=world_size,
dtype=dtype,
replace_method="auto",
replace_with_kernel_inject=True,
)
prompt = "DeepSpeed is"
string_1 = generator(prompt, do_sample=False, max_length=128)
string_2 = generator(prompt, do_sample=False, max_length=128)
assert string_1 == string_2
ppl_diff = abs(bs_output["results"][task]["ppl"] -
ds_output["results"][task]["ppl"])
assert ds_time <= bs_time
assert ppl_diff < 0.01
_go()