DeepSpeed Chat (#3186)

Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: yaozhewei <zheweiy@berkeley.edu>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
Co-authored-by: Lok Chand Koppaka <lokoppak@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Author: Olatunji Ruwase
Date: 2023-04-11 14:53:38 -04:00 (committed by GitHub)
Commit: 47f9f13bd3
Parent: cc9dfffbaf
49 changed files with 1445 additions and 405 deletions


@ -26,7 +26,6 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
# put all valid class name <--> class type mapping into class_dict
op_builder_dir = self.op_builder_dir()
op_builder_module = importlib.import_module(op_builder_dir)
for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
# avoid self references
if module_name != 'all_ops' and module_name != 'builder':


@ -44,9 +44,9 @@ inline int DS_GET_BLOCKS(const int N)
1);
}
class Context {
class TrainingContext {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
TrainingContext() : _workspace(nullptr), _seed(42), _curr_offset(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, 123);
@ -57,15 +57,15 @@ public:
}
}
virtual ~Context()
virtual ~TrainingContext()
{
cublasDestroy(_cublasHandle);
cudaFree(_workspace);
}
static Context& Instance()
static TrainingContext& Instance()
{
static Context _ctx;
static TrainingContext _ctx;
return _ctx;
}


@ -39,8 +39,8 @@ public:
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
_streams[0] = Context::Instance().GetCurrentStream();
_streams[1] = Context::Instance().GetNewStream();
_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#endif
}


@ -54,8 +54,8 @@ public:
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
_streams[0] = Context::Instance().GetCurrentStream();
_streams[1] = Context::Instance().GetNewStream();
_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#endif
}


@ -457,7 +457,7 @@ void launch_sr_fake_quantize_kernel(T* vals,
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
@ -1011,7 +1011,7 @@ void launch_sr_fake_quantize_kernel_asym(T* vals,
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);


@ -278,7 +278,7 @@ void launch_dropout(T* out,
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
if (bwd)
dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
total_count, ratio, vals, out, mask, seed);
@ -625,7 +625,7 @@ void launch_dropout(T* out,
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, bias, out, mask, seed);
@ -847,7 +847,7 @@ void launch_dropout(T* out,
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, input, residual, bias, out, mask, seed);


@ -78,8 +78,8 @@ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
_normalize_invertible(normalize_invertible),
_gelu_checkpoint(gelu_checkpoint),
_stochastic_mode(stochastic_mode),
_stream(Context::Instance().GetCurrentStream()),
_cublasHandle(Context::Instance().GetCublasHandle()),
_stream(TrainingContext::Instance().GetCurrentStream()),
_cublasHandle(TrainingContext::Instance().GetCublasHandle()),
_qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
3 * hidden_size,
hidden_size,
@ -183,7 +183,7 @@ void BertTransformerLayer<T>::Forward(unsigned bsz,
if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
@ -343,7 +343,7 @@ void BertTransformerLayer<T>::Backward(unsigned bsz,
if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
@ -609,25 +609,26 @@ int create_transformer_layer(unsigned layer_id,
bool gelu_checkpoint,
bool stochastic_mode)
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
TrainingContext::Instance().SetSeed(seed);
TrainingContext::Instance().TestGemmFP16(
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
layer_norm_eps,
pre_or_postLayerNorm,
Context::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
auto layer =
std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
layer_norm_eps,
pre_or_postLayerNorm,
TrainingContext::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
s_transformer_layers[layer_id] = layer;
@ -725,7 +726,7 @@ std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
@ -909,7 +910,7 @@ std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);


@ -96,7 +96,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
head_offset,
mask_stride,
mp_size,
Context::Instance().GetCurrentStream(async_op));
InferenceContext::Instance().GetCurrentStream(async_op));
return attn_scores_c;
}
@ -110,18 +110,20 @@ void allocate_workspace(unsigned hidden_dim,
unsigned mp_size = 1,
bool external_cache = false,
unsigned rank = 0,
unsigned max_out_tokens = 1024)
unsigned max_out_tokens = 1024,
unsigned min_out_tokens = 1)
{
Context::Instance().GenWorkSpace(num_layers,
num_heads,
batch_size,
prompt_length,
hidden_dim,
mp_size,
external_cache,
sizeof(T),
rank,
max_out_tokens);
InferenceContext::Instance().GenWorkSpace(num_layers,
num_heads,
batch_size,
prompt_length,
hidden_dim,
mp_size,
external_cache,
sizeof(T),
rank,
max_out_tokens,
min_out_tokens);
}
template <typename T>
@ -132,15 +134,15 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
float alpha = 1;
float gemm_beta = 0.0;
/*
// Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) {
allocate_workspace<T>(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1,
head_size); workspace = (T*)Context::Instance().GetWorkSpace();
allocate_workspace<T>(W.size(1), InferenceContext::Instance().GetMaxTokenLenght(),
Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace();
}
*/
@ -148,7 +150,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0);
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_T,
m,
@ -195,8 +197,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
auto mask_stride = get_attn_mask_stride(attn_mask);
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
@ -231,9 +234,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
0,
mask_stride,
1,
Context::Instance().GetCurrentStream(false));
InferenceContext::Instance().GetCurrentStream(false));
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
@ -364,10 +367,11 @@ void attention_unfused(T* prev_key_cont,
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
float alpha = norm_factor * norm_factor / layer_scale;
float gemm_beta = 0.0;
T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace();
T* workspace = (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace();
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
@ -378,7 +382,7 @@ void attention_unfused(T* prev_key_cont,
workspace,
CUBLAS_OP_T,
CUBLAS_OP_N,
Context::Instance().GetMaxTokenLenght() * k,
InferenceContext::Instance().GetMaxTokenLenght() * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
@ -400,7 +404,7 @@ void attention_unfused(T* prev_key_cont,
soft_len,
heads);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
@ -411,7 +415,7 @@ void attention_unfused(T* prev_key_cont,
(T*)output,
CUBLAS_OP_N,
CUBLAS_OP_N,
Context::Instance().GetMaxTokenLenght() * k,
InferenceContext::Instance().GetMaxTokenLenght() * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
@ -422,7 +426,7 @@ void attention_unfused(T* prev_key_cont,
#endif
}
void reset_cache() { Context::Instance().reset_tokens(); }
void reset_cache() { InferenceContext::Instance().reset_tokens(); }
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
@ -446,8 +450,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bool is_prompt = (seq_len > 1);
if (is_prompt) Context::Instance().reset_tokens(seq_len);
unsigned soft_len = Context::Instance().current_tokens();
if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = InferenceContext::Instance().current_tokens();
int k = hidden_dim / heads;
auto options = at::TensorOptions()
@ -456,16 +460,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
size_t buf_size = bsz * seq_len * hidden_dim;
auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options);
auto output = torch::from_blob(workspace + 3 * buf_size, {bsz, seq_len, hidden_dim}, options);
auto query_cont = workspace + 8 * buf_size;
size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) +
layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
auto query_cont = workspace + 4 * buf_size;
size_t offset =
10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLenght()) +
layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
unsigned all_tokens = soft_len;
auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1);
size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
T* temp_buf = (T*)output.data_ptr() + at::numel(output);
launch_bias_add_transform_0213<T>((T*)query_cont,
@ -482,9 +487,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
rotary_dim,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
InferenceContext::Instance().GetCurrentStream(),
3,
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetMaxTokenLenght());
if (rotary_dim > 0 && rotate_half)
launch_apply_rotary_pos_emb(query_cont,
kv_cache,
@ -496,8 +501,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLenght());
attention_unfused<T>(workspace + offset,
(T*)query_cont,
@ -522,25 +527,26 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
heads,
seq_len,
output.size(2),
Context::Instance().GetCurrentStream(false),
InferenceContext::Instance().GetCurrentStream(false),
1);
if (layer_id == num_layers - 1) Context::Instance().advance_tokens();
if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
auto prev_key = torch::from_blob(workspace + offset,
{bsz, heads, all_tokens, k},
{hidden_dim * Context::Instance().GetMaxTokenLenght(),
k * Context::Instance().GetMaxTokenLenght(),
{hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
k * InferenceContext::Instance().GetMaxTokenLenght(),
k,
1},
options);
auto prev_value = torch::from_blob(workspace + offset + value_offset,
{bsz, heads, all_tokens, k},
{hidden_dim * Context::Instance().GetMaxTokenLenght(),
k * Context::Instance().GetMaxTokenLenght(),
k,
1},
options);
auto prev_value =
torch::from_blob(workspace + offset + value_offset,
{bsz, heads, all_tokens, k},
{hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
k * InferenceContext::Instance().GetMaxTokenLenght(),
k,
1},
options);
return {output, prev_key, prev_value};
}
@ -557,7 +563,7 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
@ -583,14 +589,14 @@ at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias)
(const float*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_bias_geglu((__half*)output.data_ptr(),
(const __half*)activation.data_ptr(),
(const __half*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return output;
@ -608,7 +614,7 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
@ -624,7 +630,7 @@ at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(),
hidden_size,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
@ -641,7 +647,7 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor&
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
// InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
@ -659,7 +665,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
@ -668,7 +674,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return output;
@ -689,7 +695,7 @@ void ds_layer_norm_internal(T* workspace,
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
/* Currently only used in unit testing */
@ -714,7 +720,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
@ -725,7 +731,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return output;
@ -755,7 +761,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(),
(float*)res_output.data_ptr(),
@ -767,7 +773,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return {norm_output, res_output};
@ -782,7 +788,7 @@ void quantized_gemm(void* output,
int bsz,
int hidden_size)
{
// T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
// T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
auto options = at::TensorOptions()
.dtype(at::kHalf)
@ -797,11 +803,11 @@ void quantized_gemm(void* output,
weight.size(0),
weight.size(1),
groups,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
weight.size(0),
@ -829,10 +835,11 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& beta,
const float epsilon,
bool add_bias,
bool q_int8)
bool q_int8,
bool transposed_mode)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
workspace += (3 * bsz * input.size(2));
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);
@ -843,12 +850,12 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input.size(2),
&alpha,
@ -865,9 +872,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
(transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return torch::from_blob(workspace, input.sizes(), input.options());
}
@ -884,11 +891,12 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
bool external_cache,
unsigned mp_size,
unsigned rank,
bool q_int8)
bool q_int8,
bool transposed_mode)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
int out_size = q_int8 ? weight.size(0) : weight.size(1);
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
@ -897,8 +905,17 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
.requires_grad(false);
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
auto inp_norm = qkv_unfused_cublas<T>(
output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8);
auto inp_norm = qkv_unfused_cublas<T>(output,
input,
weight,
q_scale,
bias,
gamma,
beta,
epsilon,
add_bias,
q_int8,
transposed_mode);
return {output, inp_norm};
}
@ -926,11 +943,11 @@ void quantized_gemm(at::Tensor& output,
weight.size(1),
groups,
merge_count,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
weight.size(0),
@ -977,7 +994,7 @@ at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return output;
}
@ -988,7 +1005,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
at::Tensor& bias,
bool add_bias,
bool do_flash_attn,
int num_heads)
int num_heads,
bool transposed_mode)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
@ -999,17 +1017,18 @@ at::Tensor ds_linear_layer(at::Tensor& input,
int head_size = input_cont.size(2) / num_heads;
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input_cont.size(2),
&alpha,
@ -1025,9 +1044,9 @@ at::Tensor ds_linear_layer(at::Tensor& input,
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
if (do_flash_attn) {
if (add_padding) {
@ -1040,7 +1059,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
3 * bsz * num_heads,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
launch_bias_add_transform_0213<T>(
final_output,
@ -1057,7 +1076,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
InferenceContext::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(final_output,
@ -1082,7 +1101,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
InferenceContext::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(
@ -1100,7 +1119,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
{
int head_size = query.size(3);
int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128;
pad_head_seq(workspace,
@ -1110,7 +1129,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
query.size(2),
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
pad_head_seq(key_pad_ptr,
(T*)key.data_ptr(),
query.size(0) * query.size(1),
@ -1118,7 +1137,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
pad_head_seq(value_pad_ptr,
(T*)value.data_ptr(),
query.size(0) * query.size(1),
@ -1126,7 +1145,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return {
at::from_blob(workspace,
{query.size(0), query.size(1), query.size(2), padded_head_size},
@ -1148,7 +1167,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
int key_value_length = add_padding ? 128 : key.size(1);
int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128))
: head_size;
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length;
launch_pad_add_transform_0213(workspace,
@ -1159,7 +1178,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
query.size(1),
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
launch_pad_add_transform_0213(key_pad_ptr,
(T*)key.data_ptr(),
key.size(0),
@ -1168,7 +1187,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
launch_pad_add_transform_0213(value_pad_ptr,
(T*)value.data_ptr(),
value.size(0),
@ -1177,7 +1196,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return {
at::from_blob(
workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()),
@ -1210,7 +1229,7 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return output;
}
@ -1219,7 +1238,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
at::Tensor& weight,
bool async_op,
at::Tensor& q_scale,
bool q_int8)
bool q_int8,
bool transposed_mode)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
@ -1229,7 +1249,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
@ -1242,12 +1262,12 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input.size(2),
&alpha,
@ -1300,11 +1320,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
ActivationFuncType act_func_type)
ActivationFuncType act_func_type,
bool transposed_mode)
{
int bsz = input.size(0) * input.size(1);
T* inp_norm =
(T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output);
T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input) +
torch::numel(output);
T* intermediate = inp_norm + torch::numel(input);
if (mlp_after_attn) {
@ -1317,7 +1338,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
}
@ -1327,12 +1348,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input.size(2),
&alpha,
@ -1349,15 +1370,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
if (act_func_type == ActivationFuncType::GELU) {
launch_bias_gelu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
(transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else if (act_func_type == ActivationFuncType::ReLU) {
launch_bias_relu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
(transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
if (q_int8) {
@ -1371,14 +1392,14 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight1.size(1),
weight1.size(transposed_mode ? 0 : 1),
bsz,
weight1.size(0),
weight1.size(transposed_mode ? 1 : 0),
&alpha,
&gemm_beta,
(T*)weight1.data_ptr(),
@ -1409,7 +1430,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
int activation_type)
int activation_type,
bool transposed_mode)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
@ -1417,10 +1439,11 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
{input.size(0), input.size(1), out_size},
options);
int out_size = (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1);
auto output =
at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input),
{input.size(0), input.size(1), out_size},
options);
int bsz = input.size(0) * input.size(1);
auto act_func_type = static_cast<ActivationFuncType>(activation_type);
@ -1439,7 +1462,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
q_scale,
q_scale1,
q_int8,
act_func_type);
act_func_type,
transposed_mode);
return {output, res_add};
}
@ -1475,7 +1499,7 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return {output, residual_add};
}
@ -1490,7 +1514,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
const float epsilon,
bool preLayerNorm,
bool q_int8,
bool async_op)
bool async_op,
bool transposed_mode)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
@ -1498,9 +1523,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);
int intm_dim = q_int8 ? weight.size(0) : weight.size(1);
int intm_dim = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);
// auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
// auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() +
// torch::numel(input),
// {input.size(0), input.size(1), out_size},
// options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input);
@ -1519,10 +1545,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
bsz,
input.size(2));
} else {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
intm_dim,
bsz,
@ -1542,9 +1568,9 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
(T*)bias.data_ptr(),
intm_dim,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
int out_size = (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1);
auto output = at::empty({input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
@ -1555,8 +1581,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
bsz,
input.size(2));
} else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
out_size,
bsz,
@ -1572,8 +1598,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
// cudaEventRecord(InferenceContext::Instance().GetCompEvent(2),
// InferenceContext::Instance().GetCurrentStream(true));
return output;
}
@ -1600,7 +1626,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
hidden_size,
mp_size,
preln,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
else
launch_gptj_residual_add<T>(
static_cast<T*>(residual.data_ptr()),
@ -1611,7 +1637,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return residual;
}
@ -1641,8 +1667,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLenght());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
@ -1654,8 +1680,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLenght());
return {query_cont, key_cont};
}
@ -1684,7 +1710,7 @@ at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return output;
}
@ -1693,7 +1719,7 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
{
int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2);
Context::Instance().SynchComm();
InferenceContext::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(),
@ -1712,6 +1738,10 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
return output;
}
void ds_release_workspace() { InferenceContext::Instance().release_workspace(); }
bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
@ -1791,4 +1821,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
&allocate_workspace<__half>,
"DeepSpeed memory allocation for GPT inference with fp16 (CUDA)");
m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks");
m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace");
m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace");
}
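
The two new bindings give callers explicit control over the persistent inference workspace, which DeepSpeed-Chat's hybrid engine needs when it flips the same model between generation and training. A minimal usage sketch from Python, assuming the extension is built and loaded through the InferenceBuilder op loader (the loading path shown here is illustrative; only the release_workspace/retake_workspace names come from the bindings above):

from deepspeed.ops.op_builder import InferenceBuilder

# Build/load the transformer-inference CUDA extension (requires a CUDA toolchain).
inference_module = InferenceBuilder().load()

# ... run generation; the kernels allocate and reuse a persistent workspace ...

# Hand the workspace back to the CUDA allocator, e.g. before a training phase.
inference_module.release_workspace()

# Later, re-allocate a workspace of the previously recorded size.
# retake_workspace() returns False if the allocation fails.
ok = inference_module.retake_workspace()
assert ok, "could not re-allocate the inference workspace"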


@ -46,17 +46,20 @@ inline int DS_GET_BLOCKS(const int N)
1);
}
class Context {
class InferenceContext {
public:
Context()
InferenceContext()
: _workspace(nullptr),
_seed(42),
_curr_offset(0),
_stream(0),
_free_memory_size(0),
_num_tokens(1),
_attention_unfused_workspace_offset(0)
_attention_unfused_workspace_offset(0),
_workSpaceSize(0)
{
_workSpaceSize = 0;
_workspace = 0;
if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
auto message = std::string("Fail to create cublas handle.");
std::cerr << message << std::endl;
@ -71,7 +74,7 @@ public:
cudaEventCreate(&_comm_event);
}
virtual ~Context()
virtual ~InferenceContext()
{
cublasDestroy(_cublasHandle);
cudaFree(_workspace);
@ -81,9 +84,9 @@ public:
cudaEventDestroy(_comm_event);
}
static Context& Instance()
static InferenceContext& Instance()
{
static Context _ctx;
static InferenceContext _ctx;
return _ctx;
}
@ -96,7 +99,8 @@ public:
const bool& external_cache,
const size_t& elem_size,
const unsigned& rank,
unsigned max_out_tokens)
unsigned max_out_tokens,
unsigned min_out_tokens)
{
size_t total_size;
if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }
@ -107,9 +111,9 @@ public:
const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128);
const int effective_head_size = (head_size > 128) ? head_size : padded_head_size;
size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size;
size_t activation_size = 10 * (num_heads * effective_head_size) * batch_size;
// Other sequence length dimension is added when the final workSpaceSize is calculated
size_t temp_size = batch_size * num_heads * max_out_tokens * 2;
size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens;
size_t cache_size =
num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2;
size_t minimal_requirements =
@ -129,18 +133,16 @@ public:
: (activation_size + temp_size + cache_size))) *
_max_seq_len * elem_size;
temp_size *= _max_seq_len * elem_size;
if (rank == 0 && !_workspace)
if (_max_seq_len < min_out_tokens) {
printf(
"------------------------------------------------------\n"
"Free memory : %f (GigaBytes) \n"
"Total memory: %f (GigaBytes) \n"
"Requested memory: %f (GigaBytes) \n"
"Setting maximum total tokens (input + output) to %lu \n"
"------------------------------------------------------\n",
(float)_free_memory_size / GIGABYTE,
(float)total_size / GIGABYTE,
(float)workSpaceSize / GIGABYTE,
_max_seq_len);
"Allocatable workspace available (%d tokens) is less than minimum requested "
"workspace (%d tokens)\n",
_max_seq_len,
min_out_tokens);
throw std::runtime_error("Workspace can't be allocated, not enough memory");
}
if (!_workspace) {
assert(_workspace == nullptr);
cudaMalloc(&_workspace, workSpaceSize);
@ -148,6 +150,20 @@ public:
cudaFree(_workspace);
cudaMalloc(&_workspace, workSpaceSize);
}
if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize))
printf(
"------------------------------------------------------\n"
"Free memory : %f (GigaBytes) \n"
"Total memory: %f (GigaBytes) \n"
"Requested memory: %f (GigaBytes) \n"
"Setting maximum total tokens (input + output) to %lu \n"
"WorkSpace: %p \n"
"------------------------------------------------------\n",
(float)_free_memory_size / GIGABYTE,
(float)total_size / GIGABYTE,
(float)workSpaceSize / GIGABYTE,
_max_seq_len,
_workspace);
if (!_workspace) {
printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
@ -203,6 +219,17 @@ public:
return stream;
}
void release_workspace()
{
cudaFree(_workspace);
_workspace = nullptr;
}
bool retake_workspace()
{
if (_workspace != nullptr || _workSpaceSize == 0) return true;
cudaMalloc(&_workspace, _workSpaceSize);
return _workspace != nullptr;
}
cublasHandle_t GetCublasHandle() { return _cublasHandle; }
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)


@ -17,6 +17,7 @@ from . import module_inject
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
from .runtime.pipe.engine import PipelineEngine
from .inference.engine import InferenceEngine
from .inference.config import DeepSpeedInferenceConfig
@ -26,7 +27,7 @@ from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .utils import log_dist, OnDevice
from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed
from .runtime import zero
@ -118,31 +119,66 @@ def initialize(args=None,
assert model is not None, "deepspeed.initialize requires a model"
# Set config using config_params for backwards compat
if config is None and config_params is not None:
config = config_params
# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
if hasattr(args, "deepspeed_config"):
assert (args.deepspeed_config is
None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config
args.deepscale_config = None
# Check that we have only one config passed
if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None:
assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
config = args.deepspeed_config
assert config != None, "DeepSpeed requires --deepspeed_config to specify configuration file"
if not isinstance(model, PipelineModule):
engine = DeepSpeedEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
config_params=config_params)
config_class = DeepSpeedConfig(config, mpu)
if config_class.hybrid_engine.enabled:
engine = DeepSpeedHybridEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
config_class=config_class)
else:
engine = DeepSpeedEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
config_class=config_class)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
mpu = model.mpu()
config_class = DeepSpeedConfig(config, mpu)
engine = PipelineEngine(args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
mpu=model.mpu(),
mpu=mpu,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
config_params=config_params)
config_class=config_class)
return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
return tuple(return_items)
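
With this change, deepspeed.initialize parses the config up front and constructs a DeepSpeedHybridEngine whenever the hybrid-engine section is enabled, falling back to the regular DeepSpeedEngine otherwise. A minimal sketch of opting in, assuming the JSON section is named "hybrid_engine" (inferred from config_class.hybrid_engine.enabled; the rest of the config is illustrative):

import torch
import deepspeed

model = torch.nn.Linear(16, 16)

ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    # Assumed key: routes initialize() to DeepSpeedHybridEngine instead of DeepSpeedEngine.
    "hybrid_engine": {"enabled": True},
}

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)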


@ -197,6 +197,11 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
This can be passed through the json config too.
"""
set_empty_params: bool = False
"""
Specifies whether the inference module is created with empty or real tensors.
"""
save_mp_checkpoint_path: str = None
"""
The path for which we want to save the loaded model with a checkpoint. This
@ -247,6 +252,16 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
to the token length required for your use-case.
"""
min_out_tokens: int = Field(1, alias="min_tokens")
"""
This argument communicates to the runtime the minimum number of tokens you
expect to need for generation. If the runtime cannot provide at least this
many, it raises an error that describes the memory pressure instead of
seg-faulting or producing corrupted output.
"""
transposed_mode: bool = Field(False, alias="transposed_mode")
mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
"""
Desired model parallel size, default is 1 meaning no model parallelism.

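Together, max_out_tokens (alias max_tokens) and the new min_out_tokens (alias min_tokens) bound the KV-cache workspace: the runtime sizes the workspace up to the maximum budget and now fails fast if it cannot cover the minimum. A minimal sketch of passing both through deepspeed.init_inference, assuming the aliases are accepted as keyword arguments the way the config model declares them (the model checkpoint is illustrative):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")

engine = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
    max_tokens=2048,  # upper bound on prompt + generated tokens held in the KV cache
    min_tokens=32,    # raise a clear error if even this budget cannot be allocated
)
output_model = engine.module  # the kernel-injected model, used like the original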

@ -65,13 +65,18 @@ class DeepSpeedTransformerInference(nn.Module):
mlp_extra_grouping)
device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
if self.config.set_empty_params:
self.norm_w = None
self.norm_b = None
else:
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.layer_past = None
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
inference_cuda_module.allocate_workspace_fp16
self._alloc_workspace = True
@classmethod
def reset_cache(cls):
@ -110,12 +115,14 @@ class DeepSpeedTransformerInference(nn.Module):
input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
if self.config.layer_id == 0 and self._alloc_workspace:
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens)
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.min_out_tokens)
self._alloc_workspace = False
get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask


@ -34,14 +34,11 @@ class BaseTransformerContainer(ABC):
self.hidden_size = None
self.num_attention_heads = None
self.mp_size = self.config.tensor_parallel.tp_size
self.pre_layer_norm = self.policy.pre_attn_norm
self.pre_layer_norm = self.model_config.do_layer_norm_before if \
hasattr(self.model_config, 'do_layer_norm_before') else self.policy.pre_attn_norm
self.fp16 = False
self.attn_linear_layer = self.policy.linear_layer
self.mlp_linear_layer = self.policy.linear_layer
self.layer_norm_eps = self.model_config.layer_norm_eps if \
hasattr(self.model_config, 'layer_norm_eps') else (self.model_config.layer_norm_epsilon if \
hasattr(self.model_config, 'layer_norm_epsilon') else self.model_config.layernorm_epsilon if \
hasattr(self.model_config, 'layernorm_epsilon') else 1.0e-12)
self.return_tuple = self.config.return_tuple
self.triangular_masking = True
self.local_attention = ((self.model_config.attention_layers[self.layer_id] == "local") if hasattr(
@ -51,6 +48,7 @@ class BaseTransformerContainer(ABC):
self.training_mp_size = self.config.training_mp_size
self.bigscience_bloom = False
self.max_out_tokens = self.config.max_out_tokens
self.min_out_tokens = self.config.min_out_tokens
self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False)
self.use_mup = self.policy.use_mup
self.return_single_tuple = False
@ -75,6 +73,8 @@ class BaseTransformerContainer(ABC):
self.input_nw = None
self.input_nb = None
self.mp_group = None
def create_ds_model_config(self):
self.set_hidden_heads(*self.policy.get_hidden_heads())
assert self.num_attention_heads % self.mp_size == 0,\
@ -84,11 +84,11 @@ class BaseTransformerContainer(ABC):
self.ds_model_config = DeepSpeedInferenceConfig(
hidden_size=self.hidden_size,
heads=self.num_attention_heads,
layer_norm_eps=self.layer_norm_eps,
layer_norm_eps=self.layernorm_epsilon,
fp16=self.fp16,
pre_layer_norm=self.pre_layer_norm,
mp_size=self.mp_size,
q_int8=self.quantize,
q_int8=self.quantize if hasattr(self, 'quantize') else False,
return_tuple=self.return_tuple,
triangular_masking=self.triangular_masking,
local_attention=self.local_attention,
@ -99,18 +99,24 @@ class BaseTransformerContainer(ABC):
training_mp_size=self.training_mp_size,
bigscience_bloom=self.bigscience_bloom,
max_out_tokens=self.max_out_tokens,
min_out_tokens=self.min_out_tokens,
scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx,
use_mup=self.use_mup,
return_single_tuple=self.return_single_tuple,
)
set_empty_params=self.config.set_empty_params,
transposed_mode=self.config.transposed_mode)
return self.ds_model_config
def initialize_tensors(self):
def initialize_tensors(self, enable_training=False):
# Set the tensors from policy (user module) to container (DS module)
self.set_attention(*self.policy.attention())
self.set_attention(*self.policy.attention(enable_training=enable_training))
self.set_mlp(*self.policy.mlp())
self.set_layernorm(*self.policy.layernorm())
self.set_lora_params(self.policy.get_lora_params())
self.q_k_v = self.policy.get_q_k_v()
if self.q_k_v is not None:
self.set_q_k_v(*self.q_k_v)
def convert_to_required_dtype(self, dtype):
# Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy
@ -138,9 +144,10 @@ class BaseTransformerContainer(ABC):
self.quantize = quantize
self.quantizer = quantizer
def set_hidden_heads(self, hidden_size, num_attention_heads):
def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon):
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.layernorm_epsilon = epsilon
def set_attention(self, qkvw, qkvb, dense_w, dense_b):
self.qkvw = qkvw
@ -148,6 +155,17 @@ class BaseTransformerContainer(ABC):
self.dense_w = dense_w
self.dense_b = dense_b
def set_lora_params(self, lora_params):
self.lora_params = lora_params
def set_q_k_v(self, qw, qb, kw, kb, vw, vb):
self.qw = qw
self.qb = qb
self.kw = kw
self.kb = kb
self.vw = vw
self.vb = vb
def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
self._h4h_w = _h4h_w
self._h4h_b = _h4h_b
@ -175,33 +193,148 @@ class BaseTransformerContainer(ABC):
self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w)
self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w)
def apply_tensor_parallelism(self, mp_replace):
def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None):
reversed_dim = False
if mp_replace is None:
from deepspeed.module_inject import ReplaceWithTensorSlicing
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=tp_size, out_dim=0, in_dim=1)
reversed_dim = True
# setup the new Attention module
self.attention_qkv_mp(mp_replace)
self.attention_o_mp(mp_replace)
if self.module.attention.attn_qkvw is None:
self.attention_q_k_v_mp(mp_replace, reversed_dim=reversed_dim)
else:
self.attention_qkv_mp(mp_replace, reversed_dim=reversed_dim)
self.attention_o_mp(mp_replace, reversed_dim=reversed_dim)
# setup the new MLP module
self.mlp_inter_mp(mp_replace)
self.mlp_output_mp(mp_replace)
self.mlp_inter_mp(mp_replace, reversed_dim=reversed_dim)
self.mlp_output_mp(mp_replace, reversed_dim=reversed_dim)
# Apply weight quantization
self.apply_weight_quantization()
#self.apply_weight_quantization()
def attention_qkv_mp(self, mp_replace):
self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw, self.qkvw)
self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb, self.qkvb)
def attention_qkv_mp(self, mp_replace, reversed_dim=False):
self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw,
self.qkvw,
int8=reversed_dim)
self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb,
self.qkvb,
int8=reversed_dim)
def attention_o_mp(self, mp_replace):
self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, self.dense_w)
self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob, self.dense_b)
def attention_q_k_v_mp(self, mp_replace, reversed_dim=False):
self.module.attention.attn_qw = mp_replace.copy(self.module.attention.attn_qw[:self.qw.shape[0] //
mp_replace.mp_size],
self.qw,
int8=reversed_dim,
allocat_tensor=reversed_dim)
self.module.attention.attn_kw = mp_replace.copy(self.module.attention.attn_kw[:self.qw.shape[0] //
mp_replace.mp_size],
self.kw,
int8=reversed_dim,
allocat_tensor=reversed_dim)
self.module.attention.attn_vw = mp_replace.copy(self.module.attention.attn_vw[:self.qw.shape[0] //
mp_replace.mp_size],
self.vw,
int8=reversed_dim,
allocat_tensor=reversed_dim)
self.module.attention.attn_qb = mp_replace.copy(self.module.attention.attn_qb[:self.qw.shape[0] //
mp_replace.mp_size],
self.qb,
int8=reversed_dim,
allocat_tensor=reversed_dim)
self.module.attention.attn_kb = mp_replace.copy(self.module.attention.attn_kb[:self.qw.shape[0] //
mp_replace.mp_size],
self.kb,
int8=reversed_dim,
allocat_tensor=reversed_dim)
self.module.attention.attn_vb = mp_replace.copy(self.module.attention.attn_vb[:self.qw.shape[0] //
mp_replace.mp_size],
self.vb,
int8=reversed_dim,
allocat_tensor=reversed_dim)
def mlp_inter_mp(self, mp_replace):
self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w)
self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b)
def attention_o_mp(self, mp_replace, reversed_dim=False):
if reversed_dim:
self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow[:, :self.dense_w.shape[1] //
mp_replace.mp_size],
self.dense_w,
int8=reversed_dim,
allocat_tensor=reversed_dim)
else:
self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow,
self.dense_w,
int8=reversed_dim)
self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob,
self.dense_b,
int8=reversed_dim,
allocat_tensor=reversed_dim)
def mlp_output_mp(self, mp_replace):
self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w)
self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b, self._4hh_b)
def mlp_inter_mp(self, mp_replace, reversed_dim=False):
if reversed_dim:
self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w[:self._h4h_w.shape[0] //
mp_replace.mp_size],
self._h4h_w,
int8=reversed_dim,
allocat_tensor=reversed_dim)
self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b[:self._h4h_w.shape[0] //
mp_replace.mp_size],
self._h4h_b,
int8=reversed_dim,
allocat_tensor=reversed_dim)
else:
self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w, int8=reversed_dim)
self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b, int8=reversed_dim)
def mlp_output_mp(self, mp_replace, reversed_dim=False):
if reversed_dim:
self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w[:, :self._4hh_w.shape[1] //
mp_replace.mp_size],
self._4hh_w,
int8=reversed_dim,
allocat_tensor=reversed_dim)
else:
self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w, int8=reversed_dim)
self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b,
self._4hh_b,
int8=reversed_dim,
allocat_tensor=reversed_dim)
def release_qkv(self):
del self.module.attention.attn_qkvw
del self.module.attention.attn_qkvb
self.module.attention.attn_qkvw = None
self.module.attention.attn_qkvb = None
qkv_data = [self.module.attention.attn_qw.data, \
self.module.attention.attn_qb.data, \
self.module.attention.attn_kw.data, \
self.module.attention.attn_kb.data, \
self.module.attention.attn_vw.data, \
self.module.attention.attn_vb.data]
for data in qkv_data:
del data
self.module.attention.attn_qw = self.qw
self.module.attention.attn_qb = self.qb
self.module.attention.attn_kw = self.kw
self.module.attention.attn_kb = self.kb
self.module.attention.attn_vw = self.vw
self.module.attention.attn_vb = self.vb
def release_memory(self):
self.release_qkv()
del self.module.attention.attn_ow
del self.module.attention.attn_ob
self.module.attention.attn_ow = self.dense_w
self.module.attention.attn_ob = self.dense_b
del self.module.mlp.inter_w
del self.module.mlp.inter_b
del self.module.mlp.output_w
del self.module.mlp.output_b
self.module.mlp.inter_w = self._h4h_w
self.module.mlp.inter_b = self._h4h_b
self.module.mlp.output_w = self._4hh_w
self.module.mlp.output_b = self._4hh_b
def copy_data_to_new_module(self):
if self.attn_nw is None:
@ -234,3 +367,106 @@ class BaseTransformerContainer(ABC):
data = data.reshape(data.shape[-1], data.shape[-2])
data.to(get_accelerator().current_device_name())
return data
def reset_qkv_experimental(self):
if self.module.attention.attn_qkvw is None:
self.module.attention.attn_qkvw = torch.empty(self.qw.shape[0] * 3,
self.qw.shape[0],
dtype=self.qw.dtype,
device=self.qw.device)
self.module.attention.attn_qkvb = torch.empty(self.qw.shape[0] * 3,
dtype=self.qw.dtype,
device=self.qw.device)
self.module.attention.attn_qkvw.data[:self.qw.shape[0]] = self.qw.data
self.module.attention.attn_qkvb.data[:self.qw.shape[0]] = self.qb.data
self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] = self.vw.data
self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] = self.vb.data
qkv_data = [self.qw.data, \
self.qb.data, \
self.kw.data, \
self.kb.data, \
self.vw.data, \
self.vb.data]
self.qw.data = self.module.attention.attn_qkvw.data[:self.qw.shape[0]]
self.qb.data = self.module.attention.attn_qkvb.data[:self.qw.shape[0]]
self.kw.data = self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
self.kb.data = self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
self.vw.data = self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:]
self.vb.data = self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:]
for data in qkv_data:
del data
def reset_qkv(self):
self.qkvw.data[:self.qw.shape[0]] = self.qw.data
self.qkvb.data[:self.qw.shape[0]] = self.qb.data
self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
self.qkvw.data[2 * self.qw.shape[0]:] = self.vw.data
self.qkvb.data[2 * self.qw.shape[0]:] = self.vb.data
qkv_data = [self.qw.data, \
self.qb.data, \
self.kw.data, \
self.kb.data, \
self.vw.data, \
self.vb.data]
self.qw.data = self.qkvw.data[:self.qw.shape[0]]
self.qb.data = self.qkvb.data[:self.qw.shape[0]]
self.kw.data = self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
self.kb.data = self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
self.vw.data = self.qkvw.data[2 * self.qw.shape[0]:]
self.vb.data = self.qkvb.data[2 * self.qw.shape[0]:]
for data in qkv_data:
del data
def set_params_wo_copy(self, Z3_enabled=False):
self.module.mlp.attn_nw = self.attn_nw
self.module.mlp.attn_nb = self.attn_nb
self.module.norm_w = self.input_nw
self.module.norm_b = self.input_nb
self.module.mlp.inter_w = self._h4h_w
self.module.mlp.inter_b = self._h4h_b
self.module.mlp.output_w = self._4hh_w
self.module.mlp.output_b = self._4hh_b
self.module.attention.attn_ow = self.dense_w
self.module.attention.attn_ob = self.dense_b
if not Z3_enabled or self.q_k_v is None:
self.module.attention.attn_qkvw = self.qkvw
self.module.attention.attn_qkvb = self.qkvb
if self.q_k_v is not None:
if Z3_enabled:
self.module.attention.attn_qw = self.qw
self.module.attention.attn_qb = self.qb
self.module.attention.attn_kw = self.kw
self.module.attention.attn_kb = self.kb
self.module.attention.attn_vw = self.vw
self.module.attention.attn_vb = self.vb
else:
self.qw.data = self.qkvw[:self.qw.shape[0], :]
self.qb.data = self.qkvb[:self.qw.shape[0]]
self.kw.data = self.qkvw[self.qw.shape[0]:2 * self.qw.shape[0], :]
self.kb.data = self.qkvb[self.qw.shape[0]:2 * self.qw.shape[0]]
self.vw.data = self.qkvw[self.qw.shape[0] * 2:, :]
self.vb.data = self.qkvb[self.qw.shape[0] * 2:]
def get_lora_params(self):
return self.lora_params
def get_all_params(self):
if self.q_k_v is not None:
return [
self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w,
self._4hh_b, self.qw, self.qb, self.kw, self.kb, self.vw, self.vb, self.dense_w, self.dense_b
]
else:
return [
self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w,
self._4hh_b, self.qkvw, self.qkvb, self.dense_w, self.dense_b
]
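The reset_qkv / set_params_wo_copy paths above keep the split q/k/v tensors and the fused qkvw/qkvb buffers in sync by re-pointing the split tensors' .data at slices of the fused storage rather than holding independent copies. A minimal plain-PyTorch sketch of that aliasing pattern (hypothetical sizes, not DeepSpeed code):

import torch

hidden = 4
qkvw = torch.zeros(3 * hidden, hidden)
qw, kw, vw = (torch.randn(hidden, hidden) for _ in range(3))

# copy the separate projections into the fused buffer
qkvw[:hidden] = qw
qkvw[hidden:2 * hidden] = kw
qkvw[2 * hidden:] = vw

# re-alias the separate tensors onto the fused storage so later updates stay in sync
qw.data = qkvw[:hidden]
kw.data = qkvw[hidden:2 * hidden]
vw.data = qkvw[2 * hidden:]

assert qw.data_ptr() == qkvw.data_ptr()  # q now shares memory with the fused buffer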

View File

@ -44,10 +44,18 @@ class HFBertLayerPolicy(TransformerPolicy):
HFBertLayerPolicy._orig_layer_class = None
def get_hidden_heads(self):
if self.pre_attn_norm:
attention_layernorm = self.client_module.PostAttentionLayerNorm
else:
attention_layernorm = self.client_module.attention.output.LayerNorm
return self.client_module.attention.self.query.weight.shape[1], \
self.client_module.attention.self.num_attention_heads
self.client_module.attention.self.num_attention_heads, \
attention_layernorm.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
qw = self.client_module.attention.self.query.weight
qb = self.client_module.attention.self.query.bias
kw = self.client_module.attention.self.key.weight
@ -55,8 +63,8 @@ class HFBertLayerPolicy(TransformerPolicy):
vw = self.client_module.attention.self.value.weight
vb = self.client_module.attention.self.value.bias
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
return qkvw, \
qkvb, \
@ -84,3 +92,6 @@ class HFBertLayerPolicy(TransformerPolicy):
attention_layernorm.bias, \
transformer_layernorm.weight, \
transformer_layernorm.bias
def get_lora_params(self):
return []

View File

@ -28,7 +28,7 @@ class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):
self.module.config.scale_attention = self.scale_attention
return self.module
def attention_qkv_mp(self, mp_replace):
def attention_qkv_mp(self, mp_replace, reversed_dim=False):
self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)
@ -84,9 +84,13 @@ class BLOOMLayerPolicy(TransformerPolicy):
def get_hidden_heads(self):
return self.client_module.self_attention.hidden_size, \
self.client_module.self_attention.num_heads
self.client_module.self_attention.num_heads, \
self.client_module.input_layernorm.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
return self.client_module.self_attention.query_key_value.weight, \
self.client_module.self_attention.query_key_value.bias, \
self.client_module.self_attention.dense.weight, \
@ -103,3 +107,6 @@ class BLOOMLayerPolicy(TransformerPolicy):
self.client_module.post_attention_layernorm.bias, \
self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias
def get_lora_params(self):
return []

View File

@ -40,7 +40,11 @@ class HFCLIPLayerPolicy(TransformerPolicy):
def get_hidden_heads(self):
return self.client_module.self_attn.q_proj.weight.shape[1], \
self.client_module.self_attn.num_heads
self.client_module.self_attn.num_heads, \
self.client_module.layer_norm1.eps
def get_q_k_v(self):
return None
def attention(self):
qw = self.client_module.self_attn.q_proj.weight
@ -69,3 +73,6 @@ class HFCLIPLayerPolicy(TransformerPolicy):
self.client_module.layer_norm2.bias, \
self.client_module.layer_norm1.weight, \
self.client_module.layer_norm1.bias
def get_lora_params(self):
return []

View File

@ -45,9 +45,13 @@ class HFDistilBertLayerPolicy(TransformerPolicy):
def get_hidden_heads(self):
return self.client_module.attention.q_lin.weight.shape[1], \
self.client_module.attention.n_heads
self.client_module.attention.n_heads, \
self.client_module.sa_layer_norm.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
qw = self.client_module.attention.q_lin.weight
qb = self.client_module.attention.q_lin.bias
kw = self.client_module.attention.k_lin.weight
@ -55,8 +59,8 @@ class HFDistilBertLayerPolicy(TransformerPolicy):
vw = self.client_module.attention.v_lin.weight
vb = self.client_module.attention.v_lin.bias
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0))
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
return qkvw, \
qkvb, \
@ -77,3 +81,6 @@ class HFDistilBertLayerPolicy(TransformerPolicy):
attention_layernorm.bias, \
transformer_layernorm.weight, \
transformer_layernorm.bias
def get_lora_params(self):
return []

View File

@ -13,18 +13,18 @@ class MetaTensorContainer(ABC):
self.is_meta = False
self.ckpt_load_enabled = True
def initialize_tensors(self):
super().initialize_tensors()
def initialize_tensors(self, enable_training=False):
super().initialize_tensors(enable_training=enable_training)
self.is_meta = self.qkvw.is_meta
def apply_tensor_parallelism(self, mp_replace):
def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None):
if self.is_meta:
if self.qkvb is None:
self.module.attention.attn_qkvb = None
if self.dense_b is None:
self.module.attention.attn_ob = None
else:
super().apply_tensor_parallelism(mp_replace)
super().apply_tensor_parallelism(mp_replace, mp_group, tp_size)
def copy_data_to_new_module(self):
if self.is_meta:

View File

@ -37,9 +37,13 @@ class HFGPT2LayerPolicy(TransformerPolicy):
def get_hidden_heads(self):
return self.client_module.attn.embed_dim, \
self.client_module.attn.num_heads
self.client_module.attn.num_heads, \
self.client_module.ln_1.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
return self.client_module.attn.c_attn.weight, \
self.client_module.attn.c_attn.bias, \
self.client_module.attn.c_proj.weight, \
@ -56,3 +60,6 @@ class HFGPT2LayerPolicy(TransformerPolicy):
self.client_module.ln_2.bias, \
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
def get_lora_params(self):
return []

View File

@ -71,14 +71,18 @@ class HFGPTJLayerPolicy(TransformerPolicy):
def get_hidden_heads(self):
return self.client_module.attn.q_proj.weight.shape[1], \
self.client_module.attn.num_attention_heads
self.client_module.attn.num_attention_heads, \
self.client_module.ln_1.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
qw = self.client_module.attn.q_proj.weight
kw = self.client_module.attn.k_proj.weight
vw = self.client_module.attn.v_proj.weight
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
return qkvw, \
None, \
@ -96,3 +100,6 @@ class HFGPTJLayerPolicy(TransformerPolicy):
None, \
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
def get_lora_params(self):
return []

View File

@ -73,14 +73,18 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):
def get_hidden_heads(self):
return self.client_module.attn.attention.q_proj.weight.shape[1], \
self.client_module.attn.attention.num_heads
self.client_module.attn.attention.num_heads, \
self.client_module.ln_1.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
qw = self.client_module.attn.attention.q_proj.weight
kw = self.client_module.attn.attention.k_proj.weight
vw = self.client_module.attn.attention.v_proj.weight
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
return qkvw, \
None, \
@ -98,3 +102,6 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):
self.client_module.ln_2.bias, \
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
def get_lora_params(self):
return []

View File

@ -92,9 +92,13 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
attention = self.client_module.self_attention
return self.client_module.attention.query_key_value.weight.shape[1], \
self.client_module.attention.num_attention_heads
self.client_module.attention.num_attention_heads, \
self.client_module.input_layernorm.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
if GPTNEOXLayerPolicy.version == 0:
attention = self.client_module.attention
else:
@ -116,3 +120,6 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
self.client_module.post_attention_layernorm.bias, \
self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias
def get_lora_params(self):
return []

View File

@ -56,9 +56,13 @@ class MegatronLayerPolicy(TransformerPolicy):
def get_hidden_heads(self):
return self.client_module.attention.query_key_value.weight.shape[1], \
self.client_module.attention.num_attention_heads
self.client_module.attention.num_attention_heads, \
self.client_module.input_layernorm.eps
def attention(self):
def get_q_k_v(self):
return None
def attention(self, enable_training=False):
if self.inference:
if MegatronLayerPolicy.version == 0:
attention = self.client_module.attention
@ -106,3 +110,6 @@ class MegatronLayerPolicy(TransformerPolicy):
self.client_module.post_attention_layernorm.bias, \
self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias
def get_lora_params(self):
return []

View File

@ -12,6 +12,7 @@ from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora
from deepspeed.utils.types import ActivationFuncType
@ -77,20 +78,26 @@ class HFOPTLayerPolicy(TransformerPolicy):
pre_attn_norm=True,
use_load_prefix=use_load_prefix)
self.client_module = client_module
try:
import transformers
HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
if isinstance(TransformerPolicy.hf_model_config, transformers.models.opt.configuration_opt.OPTConfig):
self.pre_attn_norm = TransformerPolicy.hf_model_config.do_layer_norm_before
except:
HFOPTLayerPolicy._orig_layer_class = None
def get_hidden_heads(self):
return self.client_module.self_attn.embed_dim, \
self.client_module.self_attn.num_heads
self.client_module.self_attn.num_heads, \
self.client_module.self_attn_layer_norm.eps
def attention(self):
def get_q_k_v(self):
return self.client_module.self_attn.q_proj.weight, \
self.client_module.self_attn.q_proj.bias, \
self.client_module.self_attn.k_proj.weight, \
self.client_module.self_attn.k_proj.bias, \
self.client_module.self_attn.v_proj.weight, \
self.client_module.self_attn.v_proj.bias
def attention(self, enable_training=False):
qw = self.client_module.self_attn.q_proj.weight
qb = self.client_module.self_attn.q_proj.bias
@ -100,9 +107,8 @@ class HFOPTLayerPolicy(TransformerPolicy):
vw = self.client_module.self_attn.v_proj.weight
vb = self.client_module.self_attn.v_proj.bias
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
return qkvw, \
qkvb, \
self.client_module.self_attn.out_proj.weight, \
@ -119,3 +125,16 @@ class HFOPTLayerPolicy(TransformerPolicy):
self.client_module.final_layer_norm.bias, \
self.client_module.self_attn_layer_norm.weight, \
self.client_module.self_attn_layer_norm.bias
def get_lora_params(self):
all_lora_params = []
for p in [
self.client_module.fc1, \
self.client_module.fc2, \
self.client_module.self_attn.q_proj, \
self.client_module.self_attn.k_proj, \
self.client_module.self_attn.v_proj, \
self.client_module.self_attn.out_proj, \
]:
all_lora_params.append(maybe_get_lora(p))
return all_lora_params

View File

@ -55,22 +55,34 @@ class LinearLayer(nn.Module):
class Normalize(nn.Module):
def __init__(self, dim, dtype=torch.float, eps=1e-5):
def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None):
super(Normalize, self).__init__()
self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
self.weight = self.norm.weight
self.bias = self.norm.bias
if weight is not None:
self.weight = weight
self.bias = bias
else:
self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
self.weight = self.norm.weight
self.bias = self.norm.bias
self.eps = eps
def forward(self, input):
return self.norm(input)
return nn.functional.layer_norm(input, input.shape[-1:], self.weight, self.bias, eps=self.eps)
class EmbeddingLayer(nn.Module):
def __init__(self, weight_shape, dtype=torch.half):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(EmbeddingLayer, self).__init__()
self.weight = Parameter(
torch.empty(weight_shape[0], weight_shape[1], dtype=dtype, device=get_accelerator().current_device_name()))
if weight is None:
self.weight = Parameter(
torch.empty(weight_shape[0],
weight_shape[1],
dtype=dtype,
device=get_accelerator().current_device_name()))
else:
self.weight = weight
def forward(self, input):
return F.embedding(input, self.weight)
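Since Normalize and EmbeddingLayer can now wrap existing parameters instead of allocating new ones, the hybrid engine can reuse a model's own LayerNorm and embedding weights directly. A small sketch of the wrapping path (assumes DeepSpeed is installed; based on the constructor shown above, illustrative rather than the engine's exact call pattern):

import torch
from torch import nn
from deepspeed.module_inject.layers import Normalize

hidden = 8
ln = nn.LayerNorm(hidden)

# wrap the existing parameters instead of allocating a fresh nn.LayerNorm (new code path)
norm = Normalize(weight=ln.weight, bias=ln.bias, eps=ln.eps)

x = torch.randn(2, hidden)
# forward now runs through nn.functional.layer_norm, matching the original module
assert torch.allclose(norm(x), ln(x))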
@ -81,11 +93,11 @@ class OPTEmbedding(EmbeddingLayer):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape):
def __init__(self, weight_shape=None, weight=None, bias=None):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(weight_shape)
super().__init__(weight_shape, weight=weight)
def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""

View File

@ -72,7 +72,7 @@ class TransformerPolicy(DSPolicy):
self.split_qkv = split_qkv
@abstractmethod
def attention(self):
def attention(self, enable_training=False):
"""
Returns attention qkv and dense parameters
weight: (3*hidden, hidden) and (hidden, hidden)
@ -80,6 +80,13 @@ class TransformerPolicy(DSPolicy):
"""
raise NotImplementedError
@abstractmethod
def get_q_k_v(self):
"""
Return all q, k, v parameters separately, without merging them together
"""
raise NotImplementedError
@abstractmethod
def get_hidden_heads(self):
"""
@ -105,6 +112,14 @@ class TransformerPolicy(DSPolicy):
"""
raise NotImplementedError
@abstractmethod
def get_lora_params(self):
"""
Returns the LoRA parameters used in the transformer layer
"""
raise NotImplementedError
# TODO (lekurile): This function exists in base container as well, consolidate at some point
def transpose(data):
@ -189,3 +204,19 @@ def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names
dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \
transpose(qkv_data)), int8=weight_quantizer.q_int8)
setattr(module, dst_name, dst)
def pack_lora_weights(p):
return [
p.lora_right_weight, \
p.lora_left_weight, \
p.lora_scaling
]
def maybe_get_lora(p):
if hasattr(p, 'lora_right_weight'):
lora_param = pack_lora_weights(p)
else:
lora_param = []
return lora_param
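maybe_get_lora returns the [lora_right_weight, lora_left_weight, lora_scaling] triple packed by pack_lora_weights when the module carries LoRA attributes, and an empty list otherwise. An illustrative sketch (the LoRA attributes are attached by hand here purely for the example):

import torch
from torch import nn
from deepspeed.module_inject.policy import maybe_get_lora

plain = nn.Linear(8, 8)

lora = nn.Linear(8, 8)
# hypothetical LoRA attributes; real wrappers attach these exact names (see pack_lora_weights)
lora.lora_right_weight = nn.Parameter(torch.zeros(8, 4))
lora.lora_left_weight = nn.Parameter(torch.zeros(4, 8))
lora.lora_scaling = 0.5

assert maybe_get_lora(plain) == []              # no LoRA attributes -> empty list
right, left, scaling = maybe_get_lora(lora)     # the packed triple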

View File

@ -87,10 +87,12 @@ class ReplaceWithTensorSlicing:
dst.scale = src.scale
return dst
def copy(self, dst, src, int8=False):
def copy(self, dst, src, int8=False, allocat_tensor=False):
if src is None:
return src
assert not dst.data.is_meta # the torch.Tensor.copy_ method used below will silently fail on meta tensors
if allocat_tensor:
dst = torch.empty_like(dst)
outer_dim = 0 if int8 else 1
inner_dim = 1 if int8 else 0
src_shape = src.shape
@ -102,21 +104,21 @@ class ReplaceWithTensorSlicing:
else:
if src_shape[inner_dim] != dst_shape[self.in_dim]:
self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
weight_split = torch.split(src, dst_shape[self.in_dim], dim=inner_dim)[self.gpu_index].contiguous()
dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
src[self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim], :])
else:
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
weight_split = torch.split(src.data, dst_shape[self.out_dim],
dim=outer_dim)[self.gpu_index].contiguous()
dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(weight_split.shape)
dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
src[self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim], :])
else:
if src_shape[0] == dst_shape[0]:
dst.data.copy_(src)
dst = src
else:
bias_split = torch.split(src.data, dst_shape[-1])[self.gpu_index].contiguous()
dst.data.copy_(bias_split)
dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
dst.scale = src.scale
return dst
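The reworked copy slices the rank's block straight out of the source tensor (instead of materializing torch.split chunks) and, when allocat_tensor is set, writes into a freshly allocated destination. A numeric sketch of just the slicing arithmetic (plain tensors, not the class itself):

import torch

mp_size, gpu_index = 2, 1                                  # rank 1 of a 2-way tensor-parallel group
src = torch.arange(32, dtype=torch.float).reshape(8, 4)    # full weight
block = src.shape[0] // mp_size                            # rows owned by each rank

# equivalent of src[gpu_index * dst_shape[dim] : (gpu_index + 1) * dst_shape[dim], :]
shard = src[gpu_index * block:(gpu_index + 1) * block, :]
assert shard.shape == (4, 4)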

View File

@ -69,10 +69,13 @@ class DeepSpeedInferenceConfig(TransformerConfig):
training_mp_size=1,
bigscience_bloom=False,
max_out_tokens=1024,
min_out_tokens=1,
enable_qkv_quantization=False,
use_mup=False,
scale_attn_by_inverse_layer_idx=False,
return_single_tuple=False):
return_single_tuple=False,
set_empty_params=False,
transposed_mode=False):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@ -97,10 +100,13 @@ class DeepSpeedInferenceConfig(TransformerConfig):
self.training_mp_size = training_mp_size
self.bigscience_bloom = bigscience_bloom
self.max_out_tokens = max_out_tokens
self.min_out_tokens = min_out_tokens
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
self.enable_qkv_quantization = enable_qkv_quantization
self.use_mup = use_mup
self.return_single_tuple = return_single_tuple
self.set_empty_params = set_empty_params
self.transposed_mode = transposed_mode
@classmethod
def from_dict(cls, json_object):

View File

@ -15,6 +15,7 @@ minus_inf = -10000.0
class DeepSpeedSelfAttention(nn.Module):
num_layers = 0
_qkv_buffers = []
def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1):
super(DeepSpeedSelfAttention, self).__init__()
@ -24,23 +25,35 @@ class DeepSpeedSelfAttention(nn.Module):
self.config.layer_id = DeepSpeedSelfAttention.num_layers
DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
device = get_accelerator().current_device_name() #if config.bigscience_bloom else 'cpu'
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
qkv_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
requires_grad=False)
out_size_per_partition = self.config.hidden_size // self.config.mp_size
self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
if self.config.set_empty_params:
self.attn_qw = None
self.attn_qb = None
self.attn_kw = None
self.attn_kb = None
self.attn_vw = None
self.attn_vb = None
self.attn_qkvw = None
self.attn_qkvb = None
self.attn_ow = None
self.attn_ob = None
else:
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
qkv_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
requires_grad=False)
out_size_per_partition = self.config.hidden_size // self.config.mp_size
self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
@ -65,6 +78,14 @@ class DeepSpeedSelfAttention(nn.Module):
self.score_context_func = SoftmaxContextOp(config)
self.linear_func = LinearOp(config)
self.vector_matmul_func = VectorMatMulOp(config)
if len(DeepSpeedSelfAttention._qkv_buffers) == 0:
DeepSpeedSelfAttention._qkv_buffers = [
torch.empty(self.hidden_size_per_partition * 3,
self.config.hidden_size,
dtype=data_type_fp,
device=device),
torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device)
]
def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
if isinstance(qkv_out, list):
@ -89,6 +110,18 @@ class DeepSpeedSelfAttention(nn.Module):
context_layer, key_layer, value_layer = attn_key_value
return context_layer, key_layer, value_layer
def _merge_qkv(self):
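# Pack the separate q/k/v partitions into the shared class-level scratch buffers so the
# fused-QKV kernels can run even when only attn_qw/attn_kw/attn_vw are populated
# (the set_empty_params / hybrid-engine path, where attn_qkvw is None).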
qvkw = DeepSpeedSelfAttention._qkv_buffers[0]
qvkw[:self.hidden_size_per_partition, :] = self.attn_qw
qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = self.attn_kw
qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw
if self.attn_qb is not None:
qvkb = DeepSpeedSelfAttention._qkv_buffers[1]
qvkb[:self.hidden_size_per_partition] = self.attn_qb
qvkb[self.hidden_size_per_partition:2 * self.hidden_size_per_partition] = self.attn_kb
qvkb[2 * self.hidden_size_per_partition:] = self.attn_vb
return DeepSpeedSelfAttention._qkv_buffers
def forward(self,
input,
input_mask,
@ -101,30 +134,33 @@ class DeepSpeedSelfAttention(nn.Module):
norm_w=None,
norm_b=None,
alibi=None):
if self.attn_qkvw is None:
self._attn_qkvw, self._attn_qkvb = self._merge_qkv()
else:
self._attn_qkvw = self.attn_qkvw
self._attn_qkvb = self.attn_qkvb
if not self.config.pre_layer_norm:
qkv_out = self.linear_func(input=input,
weight=self.attn_qkvw,
bias=self.attn_qkvb,
weight=self._attn_qkvw,
bias=self._attn_qkvb,
add_bias=self.attn_qkvb is not None,
do_flash_attn=False,
num_heads=self.num_attention_heads_per_partition,
num_layers=DeepSpeedSelfAttention.num_layers)
else:
qkv_out = self.qkv_func(input=input,
weight=self.attn_qkvw,
bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
weight=self._attn_qkvw,
bias=(self._attn_qkvb if self._attn_qkvb is not None else norm_b),
gamma=norm_w,
beta=norm_b,
add_bias=(self.attn_qkvb is not None),
num_layers=DeepSpeedSelfAttention.num_layers,
num_heads=self.num_attention_heads_per_partition)
context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
input_mask=input_mask,
layer_past=layer_past,
alibi=alibi)
output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)
inp_norm = qkv_out[-1]

View File

@ -20,25 +20,33 @@ class DeepSpeedMLP(nn.Module):
data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
data_type_fp = torch.half if config.fp16 else torch.float
device = get_accelerator().current_device_name()
self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
intm_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
requires_grad=False)
self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
if self.config.set_empty_params:
self.attn_nw = None
self.attn_nb = None
self.inter_w = None
self.inter_b = None
self.output_w = None
self.output_b = None
else:
self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
intm_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
requires_grad=False)
self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
# used for quantization
self.q_scales = q_scales
@ -74,8 +82,6 @@ class DeepSpeedMLP(nn.Module):
final_bias=self.output_b,
add_bias=bias is not None,
residual_add=residual_add)
if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
dist.all_reduce(residual, group=self.mp_group)
return residual

View File

@ -24,5 +24,6 @@ class GELUGemmOp(BaseOp):
weight_out: torch.Tensor,
async_op: bool = False):
output = self.fused_gemm_gelu(input, weight, weight.scale, bias, weight_out, weight_out.scale,
self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8, async_op)
self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8, async_op,
self.config.transposed_mode)
return output

View File

@ -26,5 +26,6 @@ class LinearOp(BaseOp):
num_heads: int,
external_cache: bool = None,
num_layers: int = None):
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads)
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
self.config.transposed_mode)
return qkv_out

View File

@ -20,8 +20,10 @@ class MLPGemmOp(BaseOp):
def forward(self, input: torch.Tensor, residual: torch.Tensor, input_bias: torch.Tensor,
weight_interm: torch.Tensor, weight_out: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor,
beta: torch.Tensor):
output, residual_add = self.mlp_gemm_func(input, residual, input_bias, weight_interm, weight_out, bias, gamma,
beta, self.config.epsilon, self.config.pre_layer_norm,
self.config.mlp_after_attn, weight_interm.scale, weight_out.scale,
self.config.q_int8, self.config.mlp_act_func_type)
output, residual_add = self.mlp_gemm_func(
input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, self.config.epsilon,
self.config.pre_layer_norm, self.config.mlp_after_attn,
weight_interm.scale if hasattr(weight_interm, 'scale') else torch.empty(1),
weight_out.scale if hasattr(weight_out, 'scale') else torch.empty(1), self.config.q_int8,
self.config.mlp_act_func_type, self.config.transposed_mode)
return output, residual_add

View File

@ -28,10 +28,11 @@ class QKVGemmOp(BaseOp):
num_layers: int,
num_heads: int = None,
max_out_tokens: int = None):
q_scale = weight.scale
q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1)
external_cache = self.config.bigscience_bloom
rank = dist.get_rank() if dist.is_initialized() else 0
q_int8 = self.config.q_int8
output = self.qkv_gemm_func(input, weight, q_scale, bias, gamma, beta, self.config.epsilon, add_bias,
num_layers, external_cache, self.config.mp_size, rank, q_int8)
num_layers, external_cache, self.config.mp_size, rank, q_int8,
self.config.transposed_mode)
return output

View File

@ -18,7 +18,7 @@ class VectorMatMulOp(BaseOp):
self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp32
def forward(self, input: torch.Tensor, weight: torch.Tensor, async_op: bool = False):
q_scale = weight.scale
q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1)
q_int8 = self.config.q_int8
output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8)
output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8, self.config.transposed_mode)
return output

View File

@ -31,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from ..git_version_info import version as __version__
from ..utils import logger
@ -514,6 +515,21 @@ def get_memory_breakdown(param_dict):
return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)
class HybridEngineConfig(DeepSpeedConfigModel):
enabled: bool = False
max_out_tokens: int = 512
inference_tp_size: int = 1
release_inference_cache: bool = False
pin_parameters: bool = True
tp_gather_partition_size: int = 8
def get_hybrid_engine_config(param_dict):
hybrid_engine_config_dict = param_dict.get("hybrid_engine", {})
hybrid_engine_config = HybridEngineConfig(**hybrid_engine_config_dict)
return hybrid_engine_config
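For reference, a minimal sketch of how the new hybrid_engine section could be supplied in a DeepSpeed config dict; the field names and defaults mirror HybridEngineConfig above, and the values are illustrative only:

ds_config = {
    "train_batch_size": 8,        # unrelated training setting, shown only for context
    "hybrid_engine": {
        "enabled": True,
        "max_out_tokens": 512,
        "inference_tp_size": 1,
        "release_inference_cache": False,
        "pin_parameters": True,
        "tp_gather_partition_size": 8,
    },
}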
def get_eigenvalue_config(param_dict):
if get_quantize_enabled(param_dict):
param_dict = param_dict[QUANTIZE_TRAINING]
@ -816,6 +832,8 @@ class DeepSpeedConfig(object):
self.eigenvalue_layer_num,
) = get_eigenvalue_config(param_dict)
self.hybrid_engine = get_hybrid_engine_config(param_dict)
self.sparse_attention = get_sparse_attention(param_dict)
self.pipeline = get_pipeline_config(param_dict)

View File

@ -33,7 +33,7 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER
@ -195,7 +195,7 @@ class DeepSpeedEngine(Module):
dist_init_required=None,
collate_fn=None,
config=None,
config_params=None,
config_class=None,
dont_change_device=False,
):
super(DeepSpeedEngine, self).__init__()
@ -213,6 +213,7 @@ class DeepSpeedEngine(Module):
self.gradient_average = True
self.warn_unscaled_loss = True
self.config = config
self._config = config_class
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
@ -242,10 +243,6 @@ class DeepSpeedEngine(Module):
# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}
# Set config using config_params for backwards compat
if self.config is None and config_params is not None:
self.config = config_params
from deepspeed.comm import supported_torch_version
# This supported_torch_version check is for torch1.2 compatibility only
if supported_torch_version:
@ -949,19 +946,8 @@ class DeepSpeedEngine(Module):
if hasattr(args, 'local_rank'):
args.local_rank = self.local_rank
if self.config is None:
self.config = (args.deepspeed_config if hasattr(args, "deepspeed_config") else None)
self._config = DeepSpeedConfig(self.config, mpu)
# Validate command line arguments
def _do_args_sanity_check(self, args):
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
if hasattr(args, "deepspeed_config"):
assert (args.deepspeed_config is
None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config
assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \
"variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \
"different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
@ -975,10 +961,6 @@ class DeepSpeedEngine(Module):
env_local_rank == args.local_rank
), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
if self.config is None:
assert (hasattr(args, "deepspeed_config") and args.deepspeed_config
is not None), "DeepSpeed requires --deepspeed_config to specify configuration file"
def _is_supported_optimizer(self, optimizer_name):
return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None)

View File

@ -60,6 +60,7 @@ class LossScalerBase:
def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
# print(f'LossScalerBackward: {scaled_loss=}')
class LossScaler(LossScalerBase):

View File

@ -0,0 +1,404 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.inference.config import DeepSpeedInferenceConfig
from deepspeed.module_inject.replace_policy import replace_policies
from deepspeed.module_inject.utils import policy_to_ds_container
from .engine import DeepSpeedEngine
from .utils import TLinear, get_inactive_params
from deepspeed.runtime.zero import GatheredParameters
import time
import gc
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from torch import nn
from deepspeed.utils import logger
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
try:
import transformers
OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except:
OPTLearnedPositionalEmbedding = None
inference_cuda_module = None
class DeepSpeedHybridEngine(DeepSpeedEngine):
r"""DeepSpeed engine for training and inference."""
inference_mp_group = None
def __init__(self, args, model, **kwargs):
super().__init__(args, model, **kwargs)
# sync the RNG seed across all GPUs
_rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
dist.broadcast(_rng_state, 0)
get_accelerator().set_rng_state(_rng_state.cpu())
self.Z3_enabled = (self._config.zero_config.stage == 3)
self.gather_all_layers = self._config.hybrid_engine.pin_parameters
# inference containers / fwds
self._inference_containers = []
self._orig_modules = []
self._orig_fwds = []
self.create_inference_module()
# Performance stats
self._t_start = None
self._total_latency = 0
self._iters = 0
self._training_start_time = None
self._generate_latency = 0
self._training_latency = 0
self._total_batch_size = None
self._gather_latency = 0
global inference_cuda_module
if inference_cuda_module is None:
builder = InferenceBuilder()
inference_cuda_module = builder.load()
self.is_lora_fused = False
def convert_to_linear_transposed(self, model):
def _replace_linear_layer(r_module, parent_type=None, prev_type=None):
for name, child in r_module.named_children():
if child.__class__ in [torch.nn.Linear] and \
(parent_type is torch.nn.ModuleList or prev_type is torch.nn.ModuleList):
setattr(r_module, name, TLinear(child, name))
else:
_replace_linear_layer(child, type(r_module), prev_type=parent_type)
return r_module
_replace_linear_layer(model)
def new_inference_container(self, orig_layer, policy_cls, layer_id):
policy = policy_cls(orig_layer, inference=True)
_container = policy_to_ds_container(
policy=policy,
config=DeepSpeedInferenceConfig(set_empty_params=True,
max_out_tokens=self._config.hybrid_engine.max_out_tokens,
min_out_tokens=self._config.hybrid_engine.max_out_tokens,
transposed_mode=True),
model_config=self.module.config if hasattr(self.module, 'config') else None,
layer_id=layer_id,
child=orig_layer)
_container.set_dtype(self._config.fp16_enabled)
_container.set_tensor_parallel_config(self._config.hybrid_engine.inference_tp_size, self.mp_group)
_container.initialize_tensors(enable_training=True)
_container.create_ds_model_config()
_container.create_module()
_container.set_params_wo_copy(Z3_enabled=self.Z3_enabled)
return _container
def populate_all_inference_policies(self):
self.inference_policies = {}
for plcy in replace_policies:
_ = plcy(None)
if isinstance(plcy._orig_layer_class, list):
for orig_layer_class in plcy._orig_layer_class:
self.inference_policies.update({orig_layer_class: (self.new_inference_container, plcy)})
elif plcy._orig_layer_class is not None:
self.inference_policies.update({plcy._orig_layer_class: (self.new_inference_container, plcy)})
self.inference_policies.update({
nn.Linear: (LinearLayer, ),
nn.Embedding: (EmbeddingLayer, ),
nn.LayerNorm: (Normalize, ),
OPTLearnedPositionalEmbedding: (OPTEmbedding, )
})
def _fuse_lora(self, params, lora_params):
maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
for lora_param, weight in zip(lora_params, maybe_has_lora_params):
if len(lora_params) > 0:
lora_right_weight, \
lora_left_weight, \
lora_scaling = lora_param
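# Fold the low-rank LoRA update into the dense weight in place: W += scaling * (left^T @ right^T)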
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
def fuse_lora_weight(self):
for layer_id in range(len(self.layer_params)):
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
def _unfuse_lora(self, params, lora_params):
maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
for lora_param, weight in zip(lora_params, maybe_has_lora_params):
if len(lora_params) > 0:
lora_right_weight, \
lora_left_weight, \
lora_scaling = lora_param
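# Subtract the same low-rank term to restore the original dense weight for training: W -= scaling * (left^T @ right^T)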
weight.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
def unfuse_lora_weight(self):
for layer_id in range(len(self.layer_params)):
self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
def unfuse_lora_weight_non_pinned(self):
for layer_id in range(len(self.layer_params)):
non_active_params = get_inactive_params(self.layer_params[layer_id])
non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
non_active_params.extend(non_active_lora_params)
with GatheredParameters(non_active_params):
self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
def retake_inference_cache(self):
if self._config.hybrid_engine.release_inference_cache:
retake_success = inference_cuda_module.retake_workspace()
if not retake_success:
logger.warning("Unable to acquire workspace on first attempt, emtpying cache and retrying.")
gc.collect()
get_accelerator().empty_cache()
retake_success = inference_cuda_module.retake_workspace()
if not retake_success:
raise RuntimeError("Unable to retake inference workspace.")
def generate(self, *inputs, **kwargs):
if self._total_batch_size is None:
bsz = inputs[0].shape[0] if len(inputs) > 0 else \
kwargs['input_ids'].shape[0]
self._total_batch_size = bsz * dist.get_world_size()
self._t0 = time.time()
if self.Z3_enabled and self.gather_all_layers:
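# ZeRO-3 with pinned parameters: gather the partitioned layers up front (group by group
# when inference_tp_size > 1, all at once otherwise), fuse any LoRA weights into the
# gathered tensors, then run generation on the full parameters.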
if self._config.hybrid_engine.inference_tp_size > 1:
non_tp_params = []
for other_layer in self._other_layers:
non_tp_params.extend(list(other_layer.parameters()))
partition_size = self._config.hybrid_engine.tp_gather_partition_size
layer_groups = len(self.layer_params) // partition_size
for lg in range(layer_groups):
non_active_params = []
non_active_lora_params = []
for layer_id in range(lg * partition_size, min(len(self.layer_params), (lg + 1) * partition_size),
1):
non_tp_params.extend(self.layer_params[layer_id][:4])
non_active_params.extend(get_inactive_params(self.layer_params[layer_id]))
non_active_params.extend(get_inactive_params(self.layer_lora_params[layer_id]))
with GatheredParameters(non_active_params):
for layer_id in range(lg * partition_size,
min(len(self.layer_params), (lg + 1) * partition_size), 1):
if len(self.all_lora_params) > 0:
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
self._inference_containers[layer_id].apply_tensor_parallelism(
mp_group=self.mp_group, tp_size=self._config.hybrid_engine.inference_tp_size)
# TODO(cmikeh2) Evaluate if this can be deferred when release_inference_cache
# is enabled.
gc.collect()
get_accelerator().empty_cache()
self._gather_latency = time.time() - self._t0
input_shape = inputs[0].shape if len(inputs) > 0 else \
kwargs['input_ids'].shape
output = torch.zeros(
(input_shape[0] * self._config.hybrid_engine.inference_tp_size, ) + input_shape[1:],
dtype=inputs[0].dtype if len(inputs) > 0 else kwargs['input_ids'].dtype,
device=inputs[0].device if len(inputs) > 0 else kwargs['input_ids'].device)
input_cont = inputs[0].contiguous() if len(inputs) > 0 else kwargs['input_ids'].contiguous()
dist.all_gather_into_tensor(output, input_cont, group=self.mp_group)
if len(inputs) > 0:
inputs = (output, )
else:
kwargs['input_ids'] = output
self.retake_inference_cache()
non_active_params = get_inactive_params(non_tp_params)
with GatheredParameters(non_active_params):
generate_ret_vals = self._generate(*inputs, **kwargs)
for layer_id in range(len(self.layer_params)):
self._inference_containers[layer_id].release_memory()
rank = dist.get_rank(group=self.mp_group)
generate_ret_vals = generate_ret_vals[input_shape[0] * rank:input_shape[0] * (rank + 1)]
else:
non_active_layers = get_inactive_params(self.all_layers_params)
non_active_lora_params = get_inactive_params(self.all_lora_params)
non_active_layers.extend(non_active_lora_params)
with GatheredParameters(non_active_layers):
self._gather_latency = time.time() - self._t0
if len(self.all_lora_params) > 0:
self.fuse_lora_weight()
self.retake_inference_cache()
generate_ret_vals = self._generate(*inputs, **kwargs)
if len(self.all_lora_params) > 0:
self.unfuse_lora_weight()
else:
if len(self.all_lora_params) > 0 and (not self.Z3_enabled):
self.fuse_lora_weight()
self.retake_inference_cache()
generate_ret_vals = self._generate(*inputs, **kwargs)
if len(self.all_lora_params) > 0:
if (not self.Z3_enabled):
self.unfuse_lora_weight()
else:
self.unfuse_lora_weight_non_pinned()
self.is_lora_fused = False
if self._config.hybrid_engine.release_inference_cache:
inference_cuda_module.release_workspace()
gc.collect()
get_accelerator().empty_cache()
self._generate_latency = time.time() - self._t0 - self._gather_latency
return generate_ret_vals
def create_inference_containers(self, module, layer_id=0):
for name, child in module.named_children():
if child.__class__ in self.inference_policies:
if self.inference_policies[child.__class__][0] == self.new_inference_container:
self._inference_containers.append(self.inference_policies[child.__class__][0](
child, self.inference_policies[child.__class__][-1], layer_id))
self._orig_modules.append(child)
self._orig_fwds.append(child.forward)
self.layer_params.append(self._inference_containers[layer_id].get_all_params())
self.lora_params.append(self._inference_containers[layer_id].get_lora_params())
self.layer_lora_params.append([])
for lora_param in self.lora_params[layer_id]:
self.layer_lora_params[layer_id].extend(lora_param[:-1])
self.all_lora_params.extend(lora_param[:-1])
layer_id += 1
else:
self._other_layers.append(self.inference_policies[child.__class__][0](
weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
self._orig_modules_others.append(child)
self._orig_fwds_others.append(child.forward)
else:
self.create_inference_containers(child, layer_id=layer_id)
def create_inference_module(self):
self.layer_params = []
self.layer_lora_params = []
self.lora_params = []
self.all_lora_params = []
self._other_layers = []
self._orig_modules_others = []
self._orig_fwds_others = []
if self._config.hybrid_engine.inference_tp_size > 1:
global_rank = dist.get_rank()
world_size = dist.get_world_size()
mp_group_id = global_rank // self._config.hybrid_engine.inference_tp_size
num_mp_groups = world_size // self._config.hybrid_engine.inference_tp_size
for mp_group_id in range(num_mp_groups):
ranks = list(
range(mp_group_id * self._config.hybrid_engine.inference_tp_size, \
(mp_group_id + 1) * self._config.hybrid_engine.inference_tp_size, \
1)
)
mp_group = dist.new_group(ranks)
if global_rank in ranks:
self.mp_group = mp_group
else:
self.mp_group = None
self.populate_all_inference_policies()
self.all_layers_params = list(self.module.parameters())
self.create_inference_containers(self.module)
self._generate = self.module.generate
self.module.generate = self.generate
self._t0 = time.time()
def _zero3_forward(self, layer_id):
def run_forward(*inputs, **kwargs):
non_active_params = get_inactive_params(self.layer_params[layer_id])
non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
non_active_params.extend(non_active_lora_params)
with GatheredParameters(non_active_params):
if len(self.all_lora_params) > 0:
# Use the is_lora_fused flag to prevent fusing the LoRA weights more than once in ZeRO-3 with non-pinned memory
if not self.is_lora_fused:
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
# Set is_lora_fused to True when the last layer is reached
if layer_id == len(self.layer_params) - 1:
self.is_lora_fused = True
return self._inference_containers[layer_id].module.forward(*inputs, **kwargs)
return run_forward
def eval(self):
if self._t_start is not None:
latency = time.time() - self._t_start
self._total_latency = self._total_latency + latency
self._iters = self._iters + 1
if not dist.is_initialized() or dist.get_rank() == 0:
others = latency - (self._generate_latency + self._training_latency)
print(f'|E2E latency={(latency):.2f}s ' + \
f'|Gather latency={self._gather_latency:.2f}s ({(self._gather_latency / latency * 100):.2f}%) '
f'|Generate time={(self._generate_latency):.2f}s ({(self._generate_latency / latency * 100):.2f}%) ' + \
f'|Training time={(self._training_latency):.2f}s ({(self._training_latency / latency * 100):.2f}%) ' + \
f'|Others={others:.2f} ({(others / latency * 100):.2f}%)'
f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + \
f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}')
self._t_start = time.time()
self._training_latency = 0
super().eval()
if len(self._inference_containers) > 0:
for i, (orig_module, inference_container) in enumerate(zip(self._orig_modules,
self._inference_containers)):
if self.Z3_enabled and not self.gather_all_layers:
orig_module.forward = self._zero3_forward(i)
else:
orig_module.forward = inference_container.module.forward
if not self.Z3_enabled or self.gather_all_layers:
for orig_module, inference_layer in zip(self._orig_modules_others, self._other_layers):
orig_module.forward = inference_layer.forward
if self.Z3_enabled:
gc.collect()
get_accelerator().empty_cache()
if self._t_start is None:
self._t_start = time.time()
def train(self, mode=True):
if mode and len(self._orig_modules) > 0:
for orig_module, orig_fwd in zip(self._orig_modules, self._orig_fwds):
orig_module.forward = orig_fwd
for orig_module, orig_fwd in zip(self._orig_modules_others, self._orig_fwds_others):
orig_module.forward = orig_fwd
super().train(mode)
if mode:
self._training_start_time = time.time()
def step(self, lr_kwargs=None):
super().step(lr_kwargs=lr_kwargs)
if(self._inference_containers[0].module.attention.attn_qkvw is not None and \
self._inference_containers[0].q_k_v is not None):
for inference_container in self._inference_containers:
inference_container.reset_qkv()
if self._training_start_time is not None:
self._training_latency += (time.time() - self._training_start_time)
self._training_start_time = time.time()

View File

@ -30,6 +30,9 @@ from deepspeed.runtime.constants import PIPE_REPLICATED
from numpy import prod
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.policy import transpose
from torch.nn import functional as F
torch_memory_reserved = get_accelerator().memory_reserved
torch_max_memory_reserved = get_accelerator().max_memory_reserved
@ -300,6 +303,7 @@ def get_global_norm(norm_list):
total_norm = 0.0
for norm in norm_list:
total_norm += norm**2.0
# logger.info(f'norm_list = {norm_list} global = {sqrt(total_norm)}')
return sqrt(total_norm)
@ -944,3 +948,28 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align
shard_list.append(curr_shard)
dist.all_gather(shard_list, shard_list[partition_id], dp_process_group[group_id])
class TLinear(torch.nn.Linear):
def __init__(self, orig_layer, name=""):
self.name = name
super().__init__(orig_layer.weight.shape[1], orig_layer.weight.shape[0], bias=(orig_layer.bias is not None))
self.weight.data = transpose(orig_layer.weight.data)
self.bias = orig_layer.bias
self._fwd_func = self._fwd_bias_add if self.bias is not None else self._fwd
def _fwd(self, input):
return F.linear(input, self.weight)
def _fwd_bias_add(self, input):
return F.linear(input, self.weight, bias=self.bias)
def forward(self, input):
return self._fwd_func(input)
def get_inactive_params(param_list):
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
return [param for param in param_list if (hasattr(param, 'ds_id') and \
param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]
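A usage sketch of get_inactive_params mirroring the hybrid engine: gather only the ZeRO-3 parameters that are currently partitioned before touching their full data (the model argument here is hypothetical):

from deepspeed.runtime.utils import get_inactive_params
from deepspeed.runtime.zero import GatheredParameters

def full_param_count(model):
    """Count parameters at full size by gathering only those ZeRO-3 has partitioned away."""
    params = list(model.parameters())
    with GatheredParameters(get_inactive_params(params)):
        return sum(p.numel() for p in params)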

View File

@ -34,7 +34,8 @@ ZeRO optimization should be enabled as:
"offload_param": {...},
"offload_optimizer": {...},
"ignore_unused_parameters": [true|false],
"round_robin_gradients": [true|false]
"round_robin_gradients": [true|false],
"memory_efficient_linear": [true|false]
}
}
"""
@ -248,6 +249,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
between optimizer steps) or GPU count (increased parallelism).
"""
memory_efficient_linear: bool = True
"""
Use the memory-efficient linear implementation for ZeRO Stage 3.
"""
# Validators
@validator("overlap_comm")
def overlap_comm_valid(cls, field_value, values):

View File

@ -49,7 +49,6 @@ class LinearFunctionForZeroStage3(torch.autograd.Function):
@autocast_custom_fwd
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
#print("In ZeRO Linear Function")
weight_id = id(weight)
bias_id = id(bias)

View File

@ -461,7 +461,8 @@ class DeepSpeedZeRoOffload(object):
@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
param_coordinator = self.get_param_coordinator(training=sub_module.training)
assert sub_module.training, "backward pass is invalid for module in evaluation mode"
param_coordinator = self.get_param_coordinator(training=True)
param_coordinator.trace_prologue(sub_module)
if param_coordinator.is_record_trace():
param_coordinator.record_module(sub_module)
@ -469,11 +470,12 @@ class DeepSpeedZeRoOffload(object):
@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)
self.get_param_coordinator(training=sub_module.training).release_sub_module(sub_module)
self.get_param_coordinator(training=True).release_sub_module(sub_module)
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",

View File

@ -11,7 +11,6 @@ from enum import Enum
import functools
import itertools
from typing import List
import torch
from torch import Tensor
from deepspeed import comm as dist
@ -688,9 +687,10 @@ class Init(InsertPostInitMethodToModuleSubClasses):
config_dict_or_path = config
logger.warning(
f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.')
_ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path,
mpu) if config_dict_or_path is not None else None
if _ds_config is not None:
mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear
super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype)
if not dist.is_initialized():
init_distributed()
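
With the `config` argument deprecated, the DeepSpeed config reaches zero.Init through config_dict_or_path, which is also where memory_efficient_linear is picked up. A hedged usage sketch (placeholder model and config values; run under the deepspeed launcher, since zero.Init initializes torch.distributed if needed):

    import torch
    import deepspeed

    ds_config = {
        "train_batch_size": 8,
        "zero_optimization": {"stage": 3, "memory_efficient_linear": True},
    }

    # Parameters created inside the context are partitioned at construction time.
    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())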

View File

@ -138,10 +138,20 @@ class PartitionedParameterCoordinator:
def trace_prologue(self, sub_module: Module) -> None:
if self.is_complete_trace():
# sub_module must match expectation else invalidate trace cache
if len(self.__submodule_order) <= self.__step_id:
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
f"cache has only {len(self.__submodule_order)} modules",
force=True)
self._invalidate_trace()
return
if sub_module != self.__submodule_order[self.__step_id]:
expected_module_id = self.__submodule_order[self.__step_id].id
debug_rank0(f"Invalidate trace cache @ step {self.__step_id}: "
f"expected module {expected_module_id}, but got module {sub_module.id}")
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id}: "
f"expected module {expected_module_id}, but got module {sub_module.id}",
force=True)
self._invalidate_trace()
def record_module(self, sub_module: Module) -> None:
@ -187,7 +197,9 @@ class PartitionedParameterCoordinator:
self.__submodule_order = tuple(self.__submodule_order) # freeze
self.__param_order = tuple(self.__param_order) # freeze
self.__trace_mode = ZeRoTraceMode.COMPLETE
print_rank_0(f"completed record trace: {[m.id for m in self.__submodule_order]}", force=False)
print_rank_0(
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
force=False)
else:
# Enable trace recording for next forward/backward pass
self.__trace_mode = ZeRoTraceMode.RECORD
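
Both new checks above guard the same invariant: a completed trace is reusable only if the current pass visits the same sub-modules, in the same order and count, as the recorded pass. A stripped-down sketch of that record/validate cycle (not the real coordinator API):

    class TraceCache:
        # Record module execution order once, then validate later replays against it.
        RECORD, COMPLETE, INVALID = "record", "complete", "invalid"

        def __init__(self):
            self.order = []
            self.mode = self.RECORD
            self.step_id = 0

        def prologue(self, module_id):
            if self.mode == self.COMPLETE:
                # Same two conditions as trace_prologue: ran past the recorded trace,
                # or hit a different module than the one recorded at this step.
                if self.step_id >= len(self.order) or self.order[self.step_id] != module_id:
                    self.mode = self.INVALID
                    self.order = []
                    return
            elif self.mode == self.RECORD:
                self.order.append(module_id)
            self.step_id += 1

        def reset_step(self):
            # Called at iteration boundaries; a fully recorded trace becomes reusable.
            if self.mode == self.RECORD and self.order:
                self.mode = self.COMPLETE
            self.step_id = 0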

View File

@ -33,7 +33,7 @@ pg_correctness_test = False
def print_rank_0(message, debug=False, force=False):
rank = dist.get_rank()
if rank == 0 and (debug or force):
print(message)
logger.info(message)
# other variations
# - print for all ranks w/o interleaving
# printflock(f"[{rank}] {message}")
@ -1015,6 +1015,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.__reduce_and_partition_ipg_grads()
param_id = self.get_param_id(param)
assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
Gradient computed twice for this partition. \

View File

@ -0,0 +1,19 @@
{
"train_batch_size" : 32,
"train_micro_batch_size_per_gpu": 2,
"steps_per_print": 10,
"zero_optimization": {
"stage": 0,
"offload_param": {
"device": "cpu"
},
"stage3_param_persistence_threshold": 0
},
"fp16":{
"enabled": true,
"loss_scale_window": 100
},
"gradient_clipping": 1.0,
"prescale_gradients": false,
"wall_clock_breakdown" : false
}

View File

@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from transformers import AutoModelForCausalLM
import deepspeed
import argparse
from deepspeed.accelerator import get_accelerator
deepspeed.runtime.utils.see_memory_usage('pre test', force=True)
model = AutoModelForCausalLM.from_pretrained('facebook/opt-350M').half().to(get_accelerator().device_name())
parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
deepspeed.runtime.utils.see_memory_usage('post test', force=True)
m, _, _, _ = deepspeed.initialize(model=model, args=args, enable_hybrid_engine=True)
m.eval()
input = torch.ones(1, 16, device='cuda', dtype=torch.long)
out = m(input)
m.train()
out = m(input)
print(out['logits'], out['logits'].norm())
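
Because the script pulls its DeepSpeed arguments from add_config_arguments, it is meant to be launched with something like: deepspeed test_hybrid_engine.py --deepspeed_config ds_config.json (both file names are placeholders; any config along the lines of the JSON added above should do). Alternatively, the config can be handed to initialize directly, mirroring the enable_hybrid_engine flag used above:

    # Sketch: skip the CLI flags and pass the config straight to initialize().
    m, _, _, _ = deepspeed.initialize(model=model,
                                      config="ds_config.json",
                                      enable_hybrid_engine=True)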

View File

@ -1 +1 @@
0.8.3
0.9.0