DeepSpeed Chat (#3186)
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: yaozhewei <zheweiy@berkeley.edu>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
Co-authored-by: Lok Chand Koppaka <lokoppak@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
@@ -26,7 +26,6 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
         # put all valid class name <--> class type mapping into class_dict
         op_builder_dir = self.op_builder_dir()
         op_builder_module = importlib.import_module(op_builder_dir)
-
         for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
             # avoid self references
             if module_name != 'all_ops' and module_name != 'builder':
@@ -44,9 +44,9 @@ inline int DS_GET_BLOCKS(const int N)
              1);
 }

-class Context {
+class TrainingContext {
 public:
-    Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
+    TrainingContext() : _workspace(nullptr), _seed(42), _curr_offset(0)
     {
         curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
         curandSetPseudoRandomGeneratorSeed(_gen, 123);
@@ -57,15 +57,15 @@ public:
         }
     }

-    virtual ~Context()
+    virtual ~TrainingContext()
     {
         cublasDestroy(_cublasHandle);
         cudaFree(_workspace);
     }

-    static Context& Instance()
+    static TrainingContext& Instance()
     {
-        static Context _ctx;
+        static TrainingContext _ctx;
         return _ctx;
     }

@@ -39,8 +39,8 @@ public:
         cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
         cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

-        _streams[0] = Context::Instance().GetCurrentStream();
-        _streams[1] = Context::Instance().GetNewStream();
+        _streams[0] = TrainingContext::Instance().GetCurrentStream();
+        _streams[1] = TrainingContext::Instance().GetNewStream();
         _buf_index = false;
 #endif
     }

@@ -54,8 +54,8 @@ public:
         cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
         cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

-        _streams[0] = Context::Instance().GetCurrentStream();
-        _streams[1] = Context::Instance().GetNewStream();
+        _streams[0] = TrainingContext::Instance().GetCurrentStream();
+        _streams[1] = TrainingContext::Instance().GetNewStream();
         _buf_index = false;
 #endif
     }

@@ -457,7 +457,7 @@ void launch_sr_fake_quantize_kernel(T* vals,
     dim3 grid_dim(group_num);

     uint64_t inc = total_count / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
         vals, (total_count / group_num) / 4, group_num, num_bits, seed);
@@ -1011,7 +1011,7 @@ void launch_sr_fake_quantize_kernel_asym(T* vals,
     dim3 grid_dim(group_num);

     uint64_t inc = total_count / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
         vals, (total_count / group_num) / 4, group_num, num_bits, seed);

@@ -278,7 +278,7 @@ void launch_dropout(T* out,
         grid_dim.x <<= 1;
     }
     uint64_t inc = total_count / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
     if (bwd)
         dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
             total_count, ratio, vals, out, mask, seed);
@@ -625,7 +625,7 @@ void launch_dropout(T* out,
     dim3 block_dim = DS_CUDA_NUM_THREADS;

     uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
         total_count, dim, ratio, bias, out, mask, seed);
@@ -847,7 +847,7 @@ void launch_dropout(T* out,
     dim3 block_dim = DS_CUDA_NUM_THREADS;

     uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
         total_count, dim, ratio, input, residual, bias, out, mask, seed);

@@ -78,8 +78,8 @@ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
       _normalize_invertible(normalize_invertible),
       _gelu_checkpoint(gelu_checkpoint),
       _stochastic_mode(stochastic_mode),
-      _stream(Context::Instance().GetCurrentStream()),
-      _cublasHandle(Context::Instance().GetCublasHandle()),
+      _stream(TrainingContext::Instance().GetCurrentStream()),
+      _cublasHandle(TrainingContext::Instance().GetCublasHandle()),
       _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                   3 * hidden_size,
                                                   hidden_size,
@@ -183,7 +183,7 @@ void BertTransformerLayer<T>::Forward(unsigned bsz,

     if (!_stochastic_mode) cudaStreamSynchronize(_stream);

-    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
+    T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
     size_t small_buf_size = bsz * _seq_length * _hidden_size;
     T* buf_0 = workspace;
     T* buf_1 = buf_0 + small_buf_size;
@@ -343,7 +343,7 @@ void BertTransformerLayer<T>::Backward(unsigned bsz,

     if (!_stochastic_mode) cudaStreamSynchronize(_stream);

-    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
+    T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
     size_t small_buf_size = bsz * _seq_length * _hidden_size;
     T* buf_0 = workspace;
     T* buf_1 = buf_0 + small_buf_size;
@@ -609,25 +609,26 @@ int create_transformer_layer(unsigned layer_id,
                              bool gelu_checkpoint,
                              bool stochastic_mode)
 {
-    Context::Instance().SetSeed(seed);
-    Context::Instance().TestGemmFP16(
+    TrainingContext::Instance().SetSeed(seed);
+    TrainingContext::Instance().TestGemmFP16(
         test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

-    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
-                                                           batch_size,
-                                                           hidden_dim,
-                                                           num_heads,
-                                                           intermediate_size,
-                                                           init_seq_length,
-                                                           attn_dropout_ratio,
-                                                           hidden_dropout_ratio,
-                                                           layer_norm_eps,
-                                                           pre_or_postLayerNorm,
-                                                           Context::Instance().GetGemmAlgos(),
-                                                           attn_dropout_checkpoint,
-                                                           normalize_invertible,
-                                                           gelu_checkpoint,
-                                                           stochastic_mode);
+    auto layer =
+        std::make_shared<BertTransformerLayer<T>>(layer_id,
+                                                  batch_size,
+                                                  hidden_dim,
+                                                  num_heads,
+                                                  intermediate_size,
+                                                  init_seq_length,
+                                                  attn_dropout_ratio,
+                                                  hidden_dropout_ratio,
+                                                  layer_norm_eps,
+                                                  pre_or_postLayerNorm,
+                                                  TrainingContext::Instance().GetGemmAlgos(),
+                                                  attn_dropout_checkpoint,
+                                                  normalize_invertible,
+                                                  gelu_checkpoint,
+                                                  stochastic_mode);

     s_transformer_layers[layer_id] = layer;

@@ -725,7 +726,7 @@ std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                layer->IsTrainingMode(),
                                layer->GeluCheckpoint())},
                 options);
-    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+    TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());

     auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
     auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
@@ -909,7 +910,7 @@ std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                layer->IsTrainingMode(),
                                layer->GeluCheckpoint())},
                 options);
-    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+    TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());

     auto grad_input = torch::empty_like(input);
     auto grad_attn_qkvw = torch::empty_like(attn_qkvw);

@@ -96,7 +96,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
                           head_offset,
                           mask_stride,
                           mp_size,
-                          Context::Instance().GetCurrentStream(async_op));
+                          InferenceContext::Instance().GetCurrentStream(async_op));

     return attn_scores_c;
 }
@@ -110,18 +110,20 @@ void allocate_workspace(unsigned hidden_dim,
                         unsigned mp_size = 1,
                         bool external_cache = false,
                         unsigned rank = 0,
-                        unsigned max_out_tokens = 1024)
+                        unsigned max_out_tokens = 1024,
+                        unsigned min_out_tokens = 1)
 {
-    Context::Instance().GenWorkSpace(num_layers,
-                                     num_heads,
-                                     batch_size,
-                                     prompt_length,
-                                     hidden_dim,
-                                     mp_size,
-                                     external_cache,
-                                     sizeof(T),
-                                     rank,
-                                     max_out_tokens);
+    InferenceContext::Instance().GenWorkSpace(num_layers,
+                                              num_heads,
+                                              batch_size,
+                                              prompt_length,
+                                              hidden_dim,
+                                              mp_size,
+                                              external_cache,
+                                              sizeof(T),
+                                              rank,
+                                              max_out_tokens,
+                                              min_out_tokens);
 }

 template <typename T>
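For orientation, the widened entry point above is exercised from the Python side later in this same PR; the sketch below mirrors that call. The import path and module handle are assumptions for illustration, not part of this diff.

    from deepspeed.ops.op_builder import InferenceBuilder

    inference_module = InferenceBuilder().load()  # JIT-builds the CUDA extension

    # Positional args mirror the C++ signature: hidden_dim, num_heads,
    # prompt_length, batch_size, num_layers, mp_size, external_cache, rank,
    # max_out_tokens, and the newly added min_out_tokens (C++ default: 1).
    inference_module.allocate_workspace_fp16(4096, 32, 512, 1, 32, 1, False, 0, 1024, 1)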
@@ -132,15 +134,15 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
                        .layout(at::kStrided)
                        .device(at::kCUDA)
                        .requires_grad(false);
-    T* workspace = (T*)Context::Instance().GetWorkSpace();
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     float alpha = 1;
     float gemm_beta = 0.0;

     /*
     // Reallocate memory if we received a new prompt
     if (!workspace || input.size(1) != 1) {
-        allocate_workspace<T>(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1,
-        head_size); workspace = (T*)Context::Instance().GetWorkSpace();
+        allocate_workspace<T>(W.size(1), InferenceContext::Instance().GetMaxTokenLenght(),
+        Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     }
     */

@@ -148,7 +150,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
     unsigned m = W.size(1);
     unsigned n = Q.size(1) * Q.size(2);
     unsigned k = Q.size(0);
-    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+    cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    m,
@@ -195,8 +197,9 @@ void attention_unfused(at::Tensor& prev_key_cont,

     auto mask_stride = get_attn_mask_stride(attn_mask);

-    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
-    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
+    cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                    InferenceContext::Instance().GetCurrentStream());
+    cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
                                 soft_len,
                                 seq_len,
                                 k,
@@ -231,9 +234,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
                    0,
                    mask_stride,
                    1,
-                   Context::Instance().GetCurrentStream(false));
+                   InferenceContext::Instance().GetCurrentStream(false));
     alpha = 1.0;
-    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
+    cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
                                 k,
                                 seq_len,
                                 soft_len,
@@ -364,10 +367,11 @@ void attention_unfused(T* prev_key_cont,
     float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
     float alpha = norm_factor * norm_factor / layer_scale;
     float gemm_beta = 0.0;
-    T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace();
+    T* workspace = (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace();

-    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
-    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
+    cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                    InferenceContext::Instance().GetCurrentStream());
+    cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
                                 soft_len,
                                 seq_len,
                                 k,
@@ -378,7 +382,7 @@ void attention_unfused(T* prev_key_cont,
                                 workspace,
                                 CUBLAS_OP_T,
                                 CUBLAS_OP_N,
-                                Context::Instance().GetMaxTokenLenght() * k,
+                                InferenceContext::Instance().GetMaxTokenLenght() * k,
                                 seq_len * k,
                                 seq_len * soft_len,
                                 bsz * heads,
@@ -400,7 +404,7 @@ void attention_unfused(T* prev_key_cont,
                         soft_len,
                         heads);
     alpha = 1.0;
-    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
+    cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
                                 k,
                                 seq_len,
                                 soft_len,
@@ -411,7 +415,7 @@ void attention_unfused(T* prev_key_cont,
                                 (T*)output,
                                 CUBLAS_OP_N,
                                 CUBLAS_OP_N,
-                                Context::Instance().GetMaxTokenLenght() * k,
+                                InferenceContext::Instance().GetMaxTokenLenght() * k,
                                 seq_len * soft_len,
                                 seq_len * k,
                                 bsz * heads,
@@ -422,7 +426,7 @@ void attention_unfused(T* prev_key_cont,
 #endif
 }

-void reset_cache() { Context::Instance().reset_tokens(); }
+void reset_cache() { InferenceContext::Instance().reset_tokens(); }

 template <typename T>
 std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
@@ -446,8 +450,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,

     bool is_prompt = (seq_len > 1);

-    if (is_prompt) Context::Instance().reset_tokens(seq_len);
-    unsigned soft_len = Context::Instance().current_tokens();
+    if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
+    unsigned soft_len = InferenceContext::Instance().current_tokens();

     int k = hidden_dim / heads;
     auto options = at::TensorOptions()
@@ -456,16 +460,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                        .device(at::kCUDA)
                        .requires_grad(false);

-    T* workspace = (T*)Context::Instance().GetWorkSpace();
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     size_t buf_size = bsz * seq_len * hidden_dim;
-    auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options);
+    auto output = torch::from_blob(workspace + 3 * buf_size, {bsz, seq_len, hidden_dim}, options);

-    auto query_cont = workspace + 8 * buf_size;
-    size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) +
-                    layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
+    auto query_cont = workspace + 4 * buf_size;
+    size_t offset =
+        10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLenght()) +
+        layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
     unsigned all_tokens = soft_len;
     auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1);
-    size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
+    size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;

     T* temp_buf = (T*)output.data_ptr() + at::numel(output);
     launch_bias_add_transform_0213<T>((T*)query_cont,
@@ -482,9 +487,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                                       rotary_dim,
                                       rotate_half,
                                       rotate_every_two,
-                                      Context::Instance().GetCurrentStream(),
+                                      InferenceContext::Instance().GetCurrentStream(),
                                       3,
-                                      Context::Instance().GetMaxTokenLenght());
+                                      InferenceContext::Instance().GetMaxTokenLenght());
     if (rotary_dim > 0 && rotate_half)
         launch_apply_rotary_pos_emb(query_cont,
                                     kv_cache,
@@ -496,8 +501,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                                     bsz,
                                     rotate_half,
                                     rotate_every_two,
-                                    Context::Instance().GetCurrentStream(),
-                                    Context::Instance().GetMaxTokenLenght());
+                                    InferenceContext::Instance().GetCurrentStream(),
+                                    InferenceContext::Instance().GetMaxTokenLenght());

     attention_unfused<T>(workspace + offset,
                          (T*)query_cont,
@@ -522,25 +527,26 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                          heads,
                          seq_len,
                          output.size(2),
-                         Context::Instance().GetCurrentStream(false),
+                         InferenceContext::Instance().GetCurrentStream(false),
                          1);

-    if (layer_id == num_layers - 1) Context::Instance().advance_tokens();
+    if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
     auto prev_key = torch::from_blob(workspace + offset,
                                      {bsz, heads, all_tokens, k},
-                                     {hidden_dim * Context::Instance().GetMaxTokenLenght(),
-                                      k * Context::Instance().GetMaxTokenLenght(),
+                                     {hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
+                                      k * InferenceContext::Instance().GetMaxTokenLenght(),
                                       k,
                                       1},
                                      options);

-    auto prev_value = torch::from_blob(workspace + offset + value_offset,
-                                       {bsz, heads, all_tokens, k},
-                                       {hidden_dim * Context::Instance().GetMaxTokenLenght(),
-                                        k * Context::Instance().GetMaxTokenLenght(),
-                                        k,
-                                        1},
-                                       options);
+    auto prev_value =
+        torch::from_blob(workspace + offset + value_offset,
+                         {bsz, heads, all_tokens, k},
+                         {hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
+                          k * InferenceContext::Instance().GetMaxTokenLenght(),
+                          k,
+                          1},
+                         options);

     return {output, prev_key, prev_value};
 }
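The workspace arithmetic above changes in lock-step with the new allocator: the activation region shrinks from 16 to 10 sequence-sized slabs, and query_cont moves from 8 to 4 buffers in. A rough Python model of the new KV-cache offset, for intuition only (our own illustration; the C++ above is authoritative):

    def kv_cache_offset(hidden_dim, bsz, max_token_len, layer_id):
        # 10 activation slabs precede the cache; each layer then owns two
        # (key and value) slabs of bsz * max_token_len * hidden_dim elements.
        activation_region = 10 * hidden_dim * bsz * max_token_len
        per_layer_cache = 2 * bsz * max_token_len * hidden_dim
        return activation_region + layer_id * per_layer_cache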
@@ -557,7 +563,7 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
                     (T*)bias.data_ptr(),
                     intermediate_size,
                     bsz,
-                    Context::Instance().GetCurrentStream());
+                    InferenceContext::Instance().GetCurrentStream());
     return input_cont;
 }

@@ -583,14 +589,14 @@ at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias)
                                (const float*)bias.data_ptr(),
                                rows,
                                channels,
-                               Context::Instance().GetCurrentStream());
+                               InferenceContext::Instance().GetCurrentStream());
     } else {
         launch_fused_bias_geglu((__half*)output.data_ptr(),
                                 (const __half*)activation.data_ptr(),
                                 (const __half*)bias.data_ptr(),
                                 rows,
                                 channels,
-                                Context::Instance().GetCurrentStream());
+                                InferenceContext::Instance().GetCurrentStream());
     }

     return output;
@@ -608,7 +614,7 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
                     (T*)bias.data_ptr(),
                     intermediate_size,
                     bsz,
-                    Context::Instance().GetCurrentStream());
+                    InferenceContext::Instance().GetCurrentStream());
     return input_cont;
 }

@@ -624,7 +630,7 @@ at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
                    (T*)bias.data_ptr(),
                    hidden_size,
                    bsz,
-                   Context::Instance().GetCurrentStream());
+                   InferenceContext::Instance().GetCurrentStream());
     return input_cont;
 }

@@ -641,7 +647,7 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor&
     //                      bsz,
     //                      input_cont.size(2),
     //                      (bias.size(0) > 1),
-    //                      Context::Instance().GetCurrentStream());
+    //                      InferenceContext::Instance().GetCurrentStream());
     return input_cont;
 }

@@ -659,7 +665,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
                         epsilon,
                         rows,
                         elems_per_row,
-                        Context::Instance().GetCurrentStream());
+                        InferenceContext::Instance().GetCurrentStream());
     } else {
         launch_fused_ln((float*)output.data_ptr(),
                         (const float*)input.data_ptr(),
@@ -668,7 +674,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
                         epsilon,
                         rows,
                         elems_per_row,
-                        Context::Instance().GetCurrentStream());
+                        InferenceContext::Instance().GetCurrentStream());
     }

     return output;
@@ -689,7 +695,7 @@ void ds_layer_norm_internal(T* workspace,
                     epsilon,
                     bsz,
                     input.size(2),
-                    Context::Instance().GetCurrentStream());
+                    InferenceContext::Instance().GetCurrentStream());
 }

 /* Currently only used in unit testing */
@@ -714,7 +720,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
                                  epsilon,
                                  rows,
                                  elems_per_row,
-                                 Context::Instance().GetCurrentStream());
+                                 InferenceContext::Instance().GetCurrentStream());
     } else {
         launch_fused_residual_ln((float*)output.data_ptr(),
                                  (const float*)input.data_ptr(),
@@ -725,7 +731,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
                                  epsilon,
                                  rows,
                                  elems_per_row,
-                                 Context::Instance().GetCurrentStream());
+                                 InferenceContext::Instance().GetCurrentStream());
     }

     return output;
@@ -755,7 +761,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
                                epsilon,
                                rows,
                                elems_per_row,
-                               Context::Instance().GetCurrentStream());
+                               InferenceContext::Instance().GetCurrentStream());
     } else {
         launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(),
                                                   (float*)res_output.data_ptr(),
@@ -767,7 +773,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
                                                   epsilon,
                                                   rows,
                                                   elems_per_row,
-                                                  Context::Instance().GetCurrentStream());
+                                                  InferenceContext::Instance().GetCurrentStream());
     }

     return {norm_output, res_output};
@@ -782,7 +788,7 @@ void quantized_gemm(void* output,
                     int bsz,
                     int hidden_size)
 {
-    // T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
+    // T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * hidden_size * bsz;

     auto options = at::TensorOptions()
                        .dtype(at::kHalf)
@@ -797,11 +803,11 @@ void quantized_gemm(void* output,
                    weight.size(0),
                    weight.size(1),
                    groups,
-                   Context::Instance().GetCurrentStream());
+                   InferenceContext::Instance().GetCurrentStream());

     float alpha = (T)1.0;
     float gemm_beta = (T)0.0;
-    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+    cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    weight.size(0),
@@ -829,10 +835,11 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
                               at::Tensor& beta,
                               const float epsilon,
                               bool add_bias,
-                              bool q_int8)
+                              bool q_int8,
+                              bool transposed_mode)
 {
     int bsz = input.size(0) * input.size(1);
-    T* workspace = (T*)Context::Instance().GetWorkSpace();
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     workspace += (3 * bsz * input.size(2));
     ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);

@@ -843,12 +850,12 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
         float alpha = (T)1.0;
         float gemm_beta = (T)0.0;

-        cublasSetStream(Context::Instance().GetCublasHandle(),
-                        Context::Instance().GetCurrentStream());
-        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+        cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                        InferenceContext::Instance().GetCurrentStream());
+        cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
+                       (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
                        CUBLAS_OP_N,
-                       CUBLAS_OP_N,
-                       weight.size(1),
+                       weight.size(transposed_mode ? 0 : 1),
                        bsz,
                        input.size(2),
                        &alpha,
@@ -865,9 +872,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
     if (add_bias)
         launch_bias_add((T*)output.data_ptr(),
                         (T*)bias.data_ptr(),
-                        q_int8 ? weight.size(0) : weight.size(1),
+                        (transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
                         bsz,
-                        Context::Instance().GetCurrentStream());
+                        InferenceContext::Instance().GetCurrentStream());
     return torch::from_blob(workspace, input.sizes(), input.options());
 }

@@ -884,11 +891,12 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
                                     bool external_cache,
                                     unsigned mp_size,
                                     unsigned rank,
-                                    bool q_int8)
+                                    bool q_int8,
+                                    bool transposed_mode)
 {
     int bsz = input.size(0) * input.size(1);
-    T* workspace = (T*)Context::Instance().GetWorkSpace();
-    int out_size = q_int8 ? weight.size(0) : weight.size(1);
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
+    int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);

     auto options = at::TensorOptions()
                        .dtype(input.options().dtype())
@@ -897,8 +905,17 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
                        .requires_grad(false);

     auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
-    auto inp_norm = qkv_unfused_cublas<T>(
-        output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8);
+    auto inp_norm = qkv_unfused_cublas<T>(output,
+                                          input,
+                                          weight,
+                                          q_scale,
+                                          bias,
+                                          gamma,
+                                          beta,
+                                          epsilon,
+                                          add_bias,
+                                          q_int8,
+                                          transposed_mode);

     return {output, inp_norm};
 }
@@ -926,11 +943,11 @@ void quantized_gemm(at::Tensor& output,
                    weight.size(1),
                    groups,
                    merge_count,
-                   Context::Instance().GetCurrentStream());
+                   InferenceContext::Instance().GetCurrentStream());

     float alpha = (T)1.0;
     float gemm_beta = (T)0.0;
-    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+    cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    weight.size(0),
@@ -977,7 +994,7 @@ at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
                         (T*)bias.data_ptr(),
                         weight.size(1),
                         bsz,
-                        Context::Instance().GetCurrentStream());
+                        InferenceContext::Instance().GetCurrentStream());

     return output;
 }
@@ -988,7 +1005,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
                            at::Tensor& bias,
                            bool add_bias,
                            bool do_flash_attn,
-                           int num_heads)
+                           int num_heads,
+                           bool transposed_mode)
 {
     auto input_cont = input.contiguous();
     auto options = at::TensorOptions()
@@ -999,17 +1017,18 @@ at::Tensor ds_linear_layer(at::Tensor& input,

     int head_size = input_cont.size(2) / num_heads;
     int bsz = input.size(0) * input.size(1);
-    T* workspace = (T*)Context::Instance().GetWorkSpace();
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);

     float alpha = (T)1.0;
     float gemm_beta = (T)0.0;
-    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
+    cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                    InferenceContext::Instance().GetCurrentStream());

-    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+    cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
+                   (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
                    CUBLAS_OP_N,
-                   CUBLAS_OP_N,
-                   weight.size(1),
+                   weight.size(transposed_mode ? 0 : 1),
                    bsz,
                    input_cont.size(2),
                    &alpha,
@@ -1025,9 +1044,9 @@ at::Tensor ds_linear_layer(at::Tensor& input,
     if (add_bias)
         launch_bias_add((T*)output.data_ptr(),
                         (T*)bias.data_ptr(),
-                        weight.size(1),
+                        weight.size(transposed_mode ? 0 : 1),
                         bsz,
-                        Context::Instance().GetCurrentStream());
+                        InferenceContext::Instance().GetCurrentStream());
     bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
     if (do_flash_attn) {
         if (add_padding) {
@@ -1040,7 +1059,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
                          3 * bsz * num_heads,
                          head_size,
                          padded_head_size,
-                         Context::Instance().GetCurrentStream());
+                         InferenceContext::Instance().GetCurrentStream());

             launch_bias_add_transform_0213<T>(
                 final_output,
@@ -1057,7 +1076,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
                 -1,
                 false,
                 false,
-                Context::Instance().GetCurrentStream(),
+                InferenceContext::Instance().GetCurrentStream(),
                 3,
                 input.size(1));
             return at::from_blob(final_output,
@@ -1082,7 +1101,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
                 -1,
                 false,
                 false,
-                Context::Instance().GetCurrentStream(),
+                InferenceContext::Instance().GetCurrentStream(),
                 3,
                 input.size(1));
             return at::from_blob(
@@ -1100,7 +1119,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
 {
     int head_size = query.size(3);
     int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
-    T* workspace = (T*)Context::Instance().GetWorkSpace();
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2);
     T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128;
     pad_head_seq(workspace,
@@ -1110,7 +1129,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
                  query.size(2),
                  head_size,
                  padded_head_size,
-                 Context::Instance().GetCurrentStream());
+                 InferenceContext::Instance().GetCurrentStream());
     pad_head_seq(key_pad_ptr,
                  (T*)key.data_ptr(),
                  query.size(0) * query.size(1),
@@ -1118,7 +1137,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
                  128,
                  head_size,
                  padded_head_size,
-                 Context::Instance().GetCurrentStream());
+                 InferenceContext::Instance().GetCurrentStream());
     pad_head_seq(value_pad_ptr,
                  (T*)value.data_ptr(),
                  query.size(0) * query.size(1),
@@ -1126,7 +1145,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
                  128,
                  head_size,
                  padded_head_size,
-                 Context::Instance().GetCurrentStream());
+                 InferenceContext::Instance().GetCurrentStream());
     return {
         at::from_blob(workspace,
                       {query.size(0), query.size(1), query.size(2), padded_head_size},
@@ -1148,7 +1167,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
     int key_value_length = add_padding ? 128 : key.size(1);
     int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128))
                                        : head_size;
-    T* workspace = (T*)Context::Instance().GetWorkSpace();
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1);
     T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length;
     launch_pad_add_transform_0213(workspace,
@@ -1159,7 +1178,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
                                   query.size(1),
                                   heads,
                                   padded_head_size,
-                                  Context::Instance().GetCurrentStream());
+                                  InferenceContext::Instance().GetCurrentStream());
     launch_pad_add_transform_0213(key_pad_ptr,
                                   (T*)key.data_ptr(),
                                   key.size(0),
@@ -1168,7 +1187,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
                                   key_value_length,
                                   heads,
                                   padded_head_size,
-                                  Context::Instance().GetCurrentStream());
+                                  InferenceContext::Instance().GetCurrentStream());
     launch_pad_add_transform_0213(value_pad_ptr,
                                   (T*)value.data_ptr(),
                                   value.size(0),
@@ -1177,7 +1196,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
                                   key_value_length,
                                   heads,
                                   padded_head_size,
-                                  Context::Instance().GetCurrentStream());
+                                  InferenceContext::Instance().GetCurrentStream());
     return {
         at::from_blob(
             workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()),
@@ -1210,7 +1229,7 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input,
                         (T*)bias.data_ptr(),
                         weight.size(1),
                         bsz,
-                        Context::Instance().GetCurrentStream());
+                        InferenceContext::Instance().GetCurrentStream());
     return output;
 }

@@ -1219,7 +1238,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
                             at::Tensor& weight,
                             bool async_op,
                             at::Tensor& q_scale,
-                            bool q_int8)
+                            bool q_int8,
+                            bool transposed_mode)
 {
     auto options = at::TensorOptions()
                        .dtype(input.options().dtype())
@@ -1229,7 +1249,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
     int out_size = q_int8 ? weight.size(0) : weight.size(1);
     int bsz = input.size(0) * input.size(1);

-    T* workspace = (T*)Context::Instance().GetWorkSpace();
+    T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
     auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
     if (q_int8) {
         quantized_gemm<T>(output.data_ptr(),
@@ -1242,12 +1262,12 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
     } else {
         float alpha = (T)1.0;
         float gemm_beta = (T)0.0;
-        cublasSetStream(Context::Instance().GetCublasHandle(),
-                        Context::Instance().GetCurrentStream(async_op));
-        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+        cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                        InferenceContext::Instance().GetCurrentStream(async_op));
+        cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
+                       (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
                        CUBLAS_OP_N,
-                       CUBLAS_OP_N,
-                       weight.size(1),
+                       weight.size(transposed_mode ? 0 : 1),
                        bsz,
                        input.size(2),
                        &alpha,
@@ -1300,11 +1320,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
                               at::Tensor& q_scale,
                               at::Tensor& q_scale1,
                               bool q_int8,
-                              ActivationFuncType act_func_type)
+                              ActivationFuncType act_func_type,
+                              bool transposed_mode)
 {
     int bsz = input.size(0) * input.size(1);
-    T* inp_norm =
-        (T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output);
+    T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input) +
+                  torch::numel(output);
     T* intermediate = inp_norm + torch::numel(input);

     if (mlp_after_attn) {
@@ -1317,7 +1338,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
                                    epsilon,
                                    bsz,
                                    input.size(2),
-                                   Context::Instance().GetCurrentStream());
+                                   InferenceContext::Instance().GetCurrentStream());
     } else {
         ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
     }
@@ -1327,12 +1348,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
     } else {
         float alpha = (T)1.0;
         float gemm_beta = (T)0.0;
-        cublasSetStream(Context::Instance().GetCublasHandle(),
-                        Context::Instance().GetCurrentStream());
-        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+        cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                        InferenceContext::Instance().GetCurrentStream());
+        cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
+                       (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
                        CUBLAS_OP_N,
-                       CUBLAS_OP_N,
-                       weight.size(1),
+                       weight.size(transposed_mode ? 0 : 1),
                        bsz,
                        input.size(2),
                        &alpha,
@@ -1349,15 +1370,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
         if (act_func_type == ActivationFuncType::GELU) {
             launch_bias_gelu(intermediate,
                              (T*)bias.data_ptr(),
-                             q_int8 ? weight.size(0) : weight.size(1),
+                             (transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
                              bsz,
-                             Context::Instance().GetCurrentStream());
+                             InferenceContext::Instance().GetCurrentStream());
         } else if (act_func_type == ActivationFuncType::ReLU) {
             launch_bias_relu(intermediate,
                              (T*)bias.data_ptr(),
-                             q_int8 ? weight.size(0) : weight.size(1),
+                             (transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
                              bsz,
-                             Context::Instance().GetCurrentStream());
+                             InferenceContext::Instance().GetCurrentStream());
         }

     if (q_int8) {
@@ -1371,14 +1392,14 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
     } else {
         float alpha = (T)1.0;
         float gemm_beta = (T)0.0;
-        cublasSetStream(Context::Instance().GetCublasHandle(),
-                        Context::Instance().GetCurrentStream());
-        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+        cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                        InferenceContext::Instance().GetCurrentStream());
+        cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
+                       (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
                        CUBLAS_OP_N,
-                       CUBLAS_OP_N,
-                       weight1.size(1),
+                       weight1.size(transposed_mode ? 0 : 1),
                        bsz,
-                       weight1.size(0),
+                       weight1.size(transposed_mode ? 1 : 0),
                        &alpha,
                        &gemm_beta,
                        (T*)weight1.data_ptr(),
@@ -1409,7 +1430,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                     at::Tensor& q_scale,
                                     at::Tensor& q_scale1,
                                     bool q_int8,
-                                    int activation_type)
+                                    int activation_type,
+                                    bool transposed_mode)
 {
     auto options = at::TensorOptions()
                        .dtype(input.options().dtype())
@@ -1417,10 +1439,11 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                        .device(at::kCUDA)
                        .requires_grad(false);

-    int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
-    auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
-                                {input.size(0), input.size(1), out_size},
-                                options);
+    int out_size = (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1);
+    auto output =
+        at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input),
+                      {input.size(0), input.size(1), out_size},
+                      options);
     int bsz = input.size(0) * input.size(1);

     auto act_func_type = static_cast<ActivationFuncType>(activation_type);
@@ -1439,7 +1462,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                    q_scale,
                                    q_scale1,
                                    q_int8,
-                                   act_func_type);
+                                   act_func_type,
+                                   transposed_mode);

     return {output, res_add};
 }
@@ -1475,7 +1499,7 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
                         (T*)bias.data_ptr(),
                         weight.size(1),
                         bsz,
-                        Context::Instance().GetCurrentStream());
+                        InferenceContext::Instance().GetCurrentStream());

     return {output, residual_add};
 }
@@ -1490,7 +1514,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                            const float epsilon,
                            bool preLayerNorm,
                            bool q_int8,
-                           bool async_op)
+                           bool async_op,
+                           bool transposed_mode)
 {
     auto options = at::TensorOptions()
                        .dtype(input.options().dtype())
@@ -1498,9 +1523,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                        .device(at::kCUDA)
                        .requires_grad(false);

-    int intm_dim = q_int8 ? weight.size(0) : weight.size(1);
+    int intm_dim = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);

-    // auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
+    // auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() +
+    // torch::numel(input),
     //                             {input.size(0), input.size(1), out_size},
     //                             options);
     // T* intermediate = (T*)input.data_ptr() + torch::numel(input);
@@ -1519,10 +1545,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                           bsz,
                           input.size(2));
     } else {
-        cublasSetStream(Context::Instance().GetCublasHandle(),
-                        Context::Instance().GetCurrentStream());
-        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
-                       CUBLAS_OP_N,
+        cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
+                        InferenceContext::Instance().GetCurrentStream());
+        cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
+                       (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
                        CUBLAS_OP_N,
                        intm_dim,
                        bsz,
@@ -1542,9 +1568,9 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                      (T*)bias.data_ptr(),
                      intm_dim,
                      bsz,
-                     Context::Instance().GetCurrentStream());
+                     InferenceContext::Instance().GetCurrentStream());

-    int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
+    int out_size = (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1);
     auto output = at::empty({input.size(0), input.size(1), out_size}, options);
     if (q_int8) {
         quantized_gemm<T>(output.data_ptr(),
@@ -1555,8 +1581,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                           bsz,
                           input.size(2));
     } else {
-        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
-                       CUBLAS_OP_N,
+        cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
+                       (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
                        CUBLAS_OP_N,
                        out_size,
                        bsz,
@@ -1572,8 +1598,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
     }
-    // cudaEventRecord(Context::Instance().GetCompEvent(2),
-    //                 Context::Instance().GetCurrentStream(true));
+    // cudaEventRecord(InferenceContext::Instance().GetCompEvent(2),
+    //                 InferenceContext::Instance().GetCurrentStream(true));
     return output;
 }

@@ -1600,7 +1626,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
                                  hidden_size,
                                  mp_size,
                                  preln,
-                                 Context::Instance().GetCurrentStream());
+                                 InferenceContext::Instance().GetCurrentStream());
     else
         launch_gptj_residual_add<T>(
             static_cast<T*>(residual.data_ptr()),
@@ -1611,7 +1637,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
             hidden_size,
             bsz,
             mp_size,
-            Context::Instance().GetCurrentStream());
+            InferenceContext::Instance().GetCurrentStream());
     return residual;
 }

@@ -1641,8 +1667,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
                                             bsz,
                                             rotate_half,
                                             rotate_every_two,
-                                            Context::Instance().GetCurrentStream(),
-                                            Context::Instance().GetMaxTokenLenght());
+                                            InferenceContext::Instance().GetCurrentStream(),
+                                            InferenceContext::Instance().GetMaxTokenLenght());
     else
         launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
                                             (__half*)key_cont.data_ptr(),
@@ -1654,8 +1680,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
                                             bsz,
                                             rotate_half,
                                             rotate_every_two,
-                                            Context::Instance().GetCurrentStream(),
-                                            Context::Instance().GetMaxTokenLenght());
+                                            InferenceContext::Instance().GetCurrentStream(),
+                                            InferenceContext::Instance().GetMaxTokenLenght());
     return {query_cont, key_cont};
 }

@@ -1684,7 +1710,7 @@ at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
                      (T*)bias.data_ptr(),
                      weight.size(1),
                      bsz,
-                     Context::Instance().GetCurrentStream());
+                     InferenceContext::Instance().GetCurrentStream());

     return output;
 }
@@ -1693,7 +1719,7 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
 {
     int M = moe_res.size(0) * moe_res.size(1);
     int N = moe_res.size(2);
-    Context::Instance().SynchComm();
+    InferenceContext::Instance().SynchComm();
     if (moe_res.scalar_type() == at::kFloat) {
         launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
                                      (float*)coef.data_ptr(),
@@ -1712,6 +1738,10 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
     return output;
 }

+void ds_release_workspace() { InferenceContext::Instance().release_workspace(); }
+
+bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); }
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
@@ -1791,4 +1821,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
           &allocate_workspace<__half>,
           "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)");
     m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks");
+    m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace");
+    m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace");
 }
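The two new bindings give Python direct control over the workspace lifetime, which the hybrid engine introduced later in this PR uses to hand GPU memory back and forth between generation and training. A minimal sketch of the cycle (the import path and module handle are assumptions for illustration):

    from deepspeed.ops.op_builder import InferenceBuilder

    ops = InferenceBuilder().load()
    ops.release_workspace()        # frees the inference workspace (cudaFree)
    assert ops.retake_workspace()  # re-allocates the same-sized workspace later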
@@ -46,17 +46,20 @@ inline int DS_GET_BLOCKS(const int N)
              1);
 }

-class Context {
+class InferenceContext {
 public:
-    Context()
+    InferenceContext()
         : _workspace(nullptr),
           _seed(42),
           _curr_offset(0),
           _stream(0),
           _free_memory_size(0),
           _num_tokens(1),
-          _attention_unfused_workspace_offset(0)
+          _attention_unfused_workspace_offset(0),
+          _workSpaceSize(0)
     {
+        _workSpaceSize = 0;
+        _workspace = 0;
         if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
             auto message = std::string("Fail to create cublas handle.");
             std::cerr << message << std::endl;
@@ -71,7 +74,7 @@ public:
         cudaEventCreate(&_comm_event);
     }

-    virtual ~Context()
+    virtual ~InferenceContext()
    {
         cublasDestroy(_cublasHandle);
         cudaFree(_workspace);
@@ -81,9 +84,9 @@ public:
         cudaEventDestroy(_comm_event);
     }

-    static Context& Instance()
+    static InferenceContext& Instance()
     {
-        static Context _ctx;
+        static InferenceContext _ctx;
         return _ctx;
     }

@@ -96,7 +99,8 @@ public:
                       const bool& external_cache,
                       const size_t& elem_size,
                       const unsigned& rank,
-                      unsigned max_out_tokens)
+                      unsigned max_out_tokens,
+                      unsigned min_out_tokens)
     {
         size_t total_size;
         if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }
@@ -107,9 +111,9 @@ public:
         const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128);
         const int effective_head_size = (head_size > 128) ? head_size : padded_head_size;

-        size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size;
+        size_t activation_size = 10 * (num_heads * effective_head_size) * batch_size;
         // Other sequence length dimension is added when the final workSpaceSize is calculated
-        size_t temp_size = batch_size * num_heads * max_out_tokens * 2;
+        size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens;
         size_t cache_size =
             num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2;
         size_t minimal_requirements =
@@ -129,18 +133,16 @@ public:
                            : (activation_size + temp_size + cache_size))) *
                       _max_seq_len * elem_size;
         temp_size *= _max_seq_len * elem_size;
-        if (rank == 0 && !_workspace)
+
+        if (_max_seq_len < min_out_tokens) {
             printf(
-                "------------------------------------------------------\n"
-                "Free memory : %f (GigaBytes) \n"
-                "Total memory: %f (GigaBytes) \n"
-                "Requested memory: %f (GigaBytes) \n"
-                "Setting maximum total tokens (input + output) to %lu \n"
-                "------------------------------------------------------\n",
-                (float)_free_memory_size / GIGABYTE,
-                (float)total_size / GIGABYTE,
-                (float)workSpaceSize / GIGABYTE,
-                _max_seq_len);
+                "Allocatable workspace available (%d tokens) is less than minimum requested "
+                "workspace (%d tokens)\n",
+                _max_seq_len,
+                min_out_tokens);
+            throw std::runtime_error("Workspace can't be allocated, not enough memory");
+        }
+
         if (!_workspace) {
             assert(_workspace == nullptr);
             cudaMalloc(&_workspace, workSpaceSize);
@@ -148,6 +150,20 @@ public:
             cudaFree(_workspace);
             cudaMalloc(&_workspace, workSpaceSize);
         }
+        if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize))
+            printf(
+                "------------------------------------------------------\n"
+                "Free memory : %f (GigaBytes) \n"
+                "Total memory: %f (GigaBytes) \n"
+                "Requested memory: %f (GigaBytes) \n"
+                "Setting maximum total tokens (input + output) to %lu \n"
+                "WorkSpace: %p \n"
+                "------------------------------------------------------\n",
+                (float)_free_memory_size / GIGABYTE,
+                (float)total_size / GIGABYTE,
+                (float)workSpaceSize / GIGABYTE,
+                _max_seq_len,
+                _workspace);

         if (!_workspace) {
             printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
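For intuition, the sizing logic above can be approximated in a few lines. This is our simplified model (it ignores the temp buffer and the external-cache branch), not the authoritative computation:

    def est_max_total_tokens(free_bytes, num_heads, head_size, batch_size,
                             num_layers, mp_size, elem_size):
        # per-token activation pool (10 slabs) plus per-layer key/value cache
        activation = 10 * (num_heads * head_size) * batch_size
        cache = num_layers * batch_size * ((num_heads * head_size) // mp_size) * 2
        per_token = (activation + cache) * elem_size  # bytes per sequence position
        return free_bytes // per_token  # the value compared against min_out_tokens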
@@ -203,6 +219,17 @@ public:
         return stream;
     }

+    void release_workspace()
+    {
+        cudaFree(_workspace);
+        _workspace = nullptr;
+    }
+    bool retake_workspace()
+    {
+        if (_workspace != nullptr || _workSpaceSize == 0) return true;
+        cudaMalloc(&_workspace, _workSpaceSize);
+        return _workspace != nullptr;
+    }
     cublasHandle_t GetCublasHandle() { return _cublasHandle; }

     std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)

@@ -17,6 +17,7 @@ from . import module_inject

 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
+from .runtime.hybrid_engine import DeepSpeedHybridEngine
 from .runtime.pipe.engine import PipelineEngine
 from .inference.engine import InferenceEngine
 from .inference.config import DeepSpeedInferenceConfig
@@ -26,7 +27,7 @@ from .runtime.activation_checkpointing import checkpointing
 from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 from .module_inject import replace_transformer_layer, revert_transformer_layer

-from .utils import log_dist, OnDevice
+from .utils import log_dist, OnDevice, logger
 from .comm.comm import init_distributed

 from .runtime import zero
@@ -118,31 +119,66 @@ def initialize(args=None,

     assert model is not None, "deepspeed.initialize requires a model"

+    # Set config using config_params for backwards compat
+    if config is None and config_params is not None:
+        config = config_params
+
+    # Check for deepscale_config for backwards compat
+    if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
+        logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
+        if hasattr(args, "deepspeed_config"):
+            assert (args.deepspeed_config is
+                    None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
+        args.deepspeed_config = args.deepscale_config
+        args.deepscale_config = None
+
+    # Check that we have only one config passed
+    if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None:
+        assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
+        config = args.deepspeed_config
+    assert config != None, "DeepSpeed requires --deepspeed_config to specify configuration file"
+
     if not isinstance(model, PipelineModule):
-        engine = DeepSpeedEngine(args=args,
-                                 model=model,
-                                 optimizer=optimizer,
-                                 model_parameters=model_parameters,
-                                 training_data=training_data,
-                                 lr_scheduler=lr_scheduler,
-                                 mpu=mpu,
-                                 dist_init_required=dist_init_required,
-                                 collate_fn=collate_fn,
-                                 config=config,
-                                 config_params=config_params)
+        config_class = DeepSpeedConfig(config, mpu)
+        if config_class.hybrid_engine.enabled:
+            engine = DeepSpeedHybridEngine(args=args,
+                                           model=model,
+                                           optimizer=optimizer,
+                                           model_parameters=model_parameters,
+                                           training_data=training_data,
+                                           lr_scheduler=lr_scheduler,
+                                           mpu=mpu,
+                                           dist_init_required=dist_init_required,
+                                           collate_fn=collate_fn,
+                                           config=config,
+                                           config_class=config_class)
+        else:
+            engine = DeepSpeedEngine(args=args,
+                                     model=model,
+                                     optimizer=optimizer,
+                                     model_parameters=model_parameters,
+                                     training_data=training_data,
+                                     lr_scheduler=lr_scheduler,
+                                     mpu=mpu,
+                                     dist_init_required=dist_init_required,
+                                     collate_fn=collate_fn,
+                                     config=config,
+                                     config_class=config_class)
     else:
         assert mpu is None, "mpu must be None with pipeline parallelism"
+        mpu = model.mpu()
+        config_class = DeepSpeedConfig(config, mpu)
         engine = PipelineEngine(args=args,
                                 model=model,
                                 optimizer=optimizer,
                                 model_parameters=model_parameters,
                                 training_data=training_data,
                                 lr_scheduler=lr_scheduler,
-                                mpu=model.mpu(),
+                                mpu=mpu,
                                 dist_init_required=dist_init_required,
                                 collate_fn=collate_fn,
                                 config=config,
-                                config_params=config_params)
+                                config_class=config_class)

     return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
     return tuple(return_items)
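In practice the new branch is selected purely by configuration. A minimal sketch (the hybrid_engine block follows the schema this PR introduces; the model and batch size are placeholders):

    import deepspeed
    import torch.nn as nn

    model = nn.Linear(8, 8)  # placeholder model
    ds_config = {
        "train_batch_size": 8,
        # routes initialize() to DeepSpeedHybridEngine instead of DeepSpeedEngine
        "hybrid_engine": {"enabled": True},
    }
    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   model_parameters=model.parameters(),
                                                   config=ds_config)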
@ -197,6 +197,11 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
    This can be passed through the json config too.
    """

    set_empty_params: bool = False
    """
    Specifies whether the inference module is created with empty or real tensors.
    """

    save_mp_checkpoint_path: str = None
    """
    The path for which we want to save the loaded model with a checkpoint. This
@ -247,6 +252,16 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
    to the required token length for your use-case.
    """

    min_out_tokens: int = Field(1, alias="min_tokens")
    """
    This argument communicates to the runtime the minimum number of tokens you
    expect you will need to generate. This will cause the runtime to error
    if it is unable to provide this, and to report the memory pressure that
    caused the failure rather than seg-faulting or producing corrupted output.
    """

    transposed_mode: bool = Field(False, alias="transposed_mode")

    mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
    """
    Desired model parallel size, default is 1 meaning no model parallelism.
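These token bounds are typically supplied when the inference engine is built. A hedged sketch, assuming the `max_tokens`/`min_tokens` aliases shown above are accepted as keyword arguments (model is a placeholder):

import torch
import deepspeed

# The engine pre-allocates workspace for up to max_tokens and fails fast,
# with a memory-pressure message, if it cannot guarantee at least min_tokens.
engine = deepspeed.init_inference(model,
                                  dtype=torch.half,
                                  max_tokens=1024,
                                  min_tokens=1,
                                  replace_with_kernel_inject=True)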
@ -65,13 +65,18 @@ class DeepSpeedTransformerInference(nn.Module):
                                  mlp_extra_grouping)

        device = get_accelerator().current_device_name()  # if config.bigscience_bloom else 'cpu'
        self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                   requires_grad=False)
        self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                   requires_grad=False)
        if self.config.set_empty_params:
            self.norm_w = None
            self.norm_b = None
        else:
            self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
            self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
        self.layer_past = None
        self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
                        inference_cuda_module.allocate_workspace_fp16
        self._alloc_workspace = True

    @classmethod
    def reset_cache(cls):
@ -110,12 +115,14 @@ class DeepSpeedTransformerInference(nn.Module):
        input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask

        # Allocate memory only on first layer forward
        if self.config.layer_id == 0:
        if self.config.layer_id == 0 and self._alloc_workspace:
            self.allocate_workspace(self.config.hidden_size, self.config.heads,
                                    input.size()[1],
                                    input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
                                    self.config.bigscience_bloom,
                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens)
                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
                                    self.config.min_out_tokens)
            self._alloc_workspace = False

        get_present = (get_present or get_key_value or use_cache)
        input_mask = input_mask if attention_mask is None else attention_mask
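The `_alloc_workspace` flag turns the per-forward check into a one-time allocation that a caller (such as the hybrid engine) can re-arm after releasing the inference workspace. A minimal sketch of the same guard pattern; the class and method names here are illustrative, not the DeepSpeed API:

class LazyWorkspace:
    def __init__(self):
        self._alloc_workspace = True  # re-armed by whoever releases the buffer
        self.workspace = None

    def forward(self, size):
        if self._alloc_workspace:  # runs once, not on every forward
            self.workspace = bytearray(size)
            self._alloc_workspace = False
        return self.workspace

    def release(self):
        self.workspace = None
        self._alloc_workspace = True  # the next forward re-allocates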
@ -34,14 +34,11 @@ class BaseTransformerContainer(ABC):
        self.hidden_size = None
        self.num_attention_heads = None
        self.mp_size = self.config.tensor_parallel.tp_size
        self.pre_layer_norm = self.policy.pre_attn_norm
        self.pre_layer_norm = self.model_config.do_layer_norm_before if \
            hasattr(self.model_config, 'do_layer_norm_before') else self.policy.pre_attn_norm
        self.fp16 = False
        self.attn_linear_layer = self.policy.linear_layer
        self.mlp_linear_layer = self.policy.linear_layer
        self.layer_norm_eps = self.model_config.layer_norm_eps if \
            hasattr(self.model_config, 'layer_norm_eps') else (self.model_config.layer_norm_epsilon if \
            hasattr(self.model_config, 'layer_norm_epsilon') else self.model_config.layernorm_epsilon if \
            hasattr(self.model_config, 'layernorm_epsilon') else 1.0e-12)
        self.return_tuple = self.config.return_tuple
        self.triangular_masking = True
        self.local_attention = ((self.model_config.attention_layers[self.layer_id] == "local") if hasattr(
@ -51,6 +48,7 @@ class BaseTransformerContainer(ABC):
        self.training_mp_size = self.config.training_mp_size
        self.bigscience_bloom = False
        self.max_out_tokens = self.config.max_out_tokens
        self.min_out_tokens = self.config.min_out_tokens
        self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False)
        self.use_mup = self.policy.use_mup
        self.return_single_tuple = False
@ -75,6 +73,8 @@ class BaseTransformerContainer(ABC):
        self.input_nw = None
        self.input_nb = None

        self.mp_group = None

    def create_ds_model_config(self):
        self.set_hidden_heads(*self.policy.get_hidden_heads())
        assert self.num_attention_heads % self.mp_size == 0,\
@ -84,11 +84,11 @@ class BaseTransformerContainer(ABC):
        self.ds_model_config = DeepSpeedInferenceConfig(
            hidden_size=self.hidden_size,
            heads=self.num_attention_heads,
            layer_norm_eps=self.layer_norm_eps,
            layer_norm_eps=self.layernorm_epsilon,
            fp16=self.fp16,
            pre_layer_norm=self.pre_layer_norm,
            mp_size=self.mp_size,
            q_int8=self.quantize,
            q_int8=self.quantize if hasattr(self, 'quantize') else False,
            return_tuple=self.return_tuple,
            triangular_masking=self.triangular_masking,
            local_attention=self.local_attention,
@ -99,18 +99,24 @@ class BaseTransformerContainer(ABC):
            training_mp_size=self.training_mp_size,
            bigscience_bloom=self.bigscience_bloom,
            max_out_tokens=self.max_out_tokens,
            min_out_tokens=self.min_out_tokens,
            scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx,
            use_mup=self.use_mup,
            return_single_tuple=self.return_single_tuple,
        )
            set_empty_params=self.config.set_empty_params,
            transposed_mode=self.config.transposed_mode)

        return self.ds_model_config

    def initialize_tensors(self):
    def initialize_tensors(self, enable_training=False):
        # Set the tensors from policy (user module) to container (DS module)
        self.set_attention(*self.policy.attention())
        self.set_attention(*self.policy.attention(enable_training=enable_training))
        self.set_mlp(*self.policy.mlp())
        self.set_layernorm(*self.policy.layernorm())
        self.set_lora_params(self.policy.get_lora_params())
        self.q_k_v = self.policy.get_q_k_v()
        if self.q_k_v is not None:
            self.set_q_k_v(*self.q_k_v)

    def convert_to_required_dtype(self, dtype):
        # Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy
@ -138,9 +144,10 @@ class BaseTransformerContainer(ABC):
        self.quantize = quantize
        self.quantizer = quantizer

    def set_hidden_heads(self, hidden_size, num_attention_heads):
    def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.layernorm_epsilon = epsilon

    def set_attention(self, qkvw, qkvb, dense_w, dense_b):
        self.qkvw = qkvw
@ -148,6 +155,17 @@ class BaseTransformerContainer(ABC):
        self.dense_w = dense_w
        self.dense_b = dense_b

    def set_lora_params(self, lora_params):
        self.lora_params = lora_params

    def set_q_k_v(self, qw, qb, kw, kb, vw, vb):
        self.qw = qw
        self.qb = qb
        self.kw = kw
        self.kb = kb
        self.vw = vw
        self.vb = vb

    def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
        self._h4h_w = _h4h_w
        self._h4h_b = _h4h_b
@ -175,33 +193,148 @@ class BaseTransformerContainer(ABC):
        self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w)
        self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w)

    def apply_tensor_parallelism(self, mp_replace):
    def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None):
        reversed_dim = False
        if mp_replace is None:
            from deepspeed.module_inject import ReplaceWithTensorSlicing
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=tp_size, out_dim=0, in_dim=1)
            reversed_dim = True
        # setup the new Attention module
        self.attention_qkv_mp(mp_replace)
        self.attention_o_mp(mp_replace)
        if self.module.attention.attn_qkvw is None:
            self.attention_q_k_v_mp(mp_replace, reversed_dim=reversed_dim)
        else:
            self.attention_qkv_mp(mp_replace, reversed_dim=reversed_dim)
        self.attention_o_mp(mp_replace, reversed_dim=reversed_dim)

        # setup the new MLP module
        self.mlp_inter_mp(mp_replace)
        self.mlp_output_mp(mp_replace)
        self.mlp_inter_mp(mp_replace, reversed_dim=reversed_dim)
        self.mlp_output_mp(mp_replace, reversed_dim=reversed_dim)

        # Apply weight quantization
        self.apply_weight_quantization()
        #self.apply_weight_quantization()

    def attention_qkv_mp(self, mp_replace):
        self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw, self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb, self.qkvb)
    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw,
                                                              self.qkvw,
                                                              int8=reversed_dim)
        self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb,
                                                              self.qkvb,
                                                              int8=reversed_dim)

    def attention_o_mp(self, mp_replace):
        self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, self.dense_w)
        self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob, self.dense_b)
    def attention_q_k_v_mp(self, mp_replace, reversed_dim=False):
        self.module.attention.attn_qw = mp_replace.copy(self.module.attention.attn_qw[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.qw,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_kw = mp_replace.copy(self.module.attention.attn_kw[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.kw,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_vw = mp_replace.copy(self.module.attention.attn_vw[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.vw,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_qb = mp_replace.copy(self.module.attention.attn_qb[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.qb,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_kb = mp_replace.copy(self.module.attention.attn_kb[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.kb,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_vb = mp_replace.copy(self.module.attention.attn_vb[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.vb,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)

    def mlp_inter_mp(self, mp_replace):
        self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w)
        self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b)
    def attention_o_mp(self, mp_replace, reversed_dim=False):
        if reversed_dim:
            self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow[:, :self.dense_w.shape[1] //
                                                                                          mp_replace.mp_size],
                                                            self.dense_w,
                                                            int8=reversed_dim,
                                                            allocat_tensor=reversed_dim)
        else:
            self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow,
                                                            self.dense_w,
                                                            int8=reversed_dim)
        self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob,
                                                        self.dense_b,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)

    def mlp_output_mp(self, mp_replace):
        self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w)
        self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b, self._4hh_b)
    def mlp_inter_mp(self, mp_replace, reversed_dim=False):
        if reversed_dim:
            self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w[:self._h4h_w.shape[0] //
                                                                              mp_replace.mp_size],
                                                      self._h4h_w,
                                                      int8=reversed_dim,
                                                      allocat_tensor=reversed_dim)
            self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b[:self._h4h_w.shape[0] //
                                                                              mp_replace.mp_size],
                                                      self._h4h_b,
                                                      int8=reversed_dim,
                                                      allocat_tensor=reversed_dim)
        else:
            self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w, int8=reversed_dim)
            self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b, int8=reversed_dim)

    def mlp_output_mp(self, mp_replace, reversed_dim=False):
        if reversed_dim:
            self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w[:, :self._4hh_w.shape[1] //
                                                                                mp_replace.mp_size],
                                                       self._4hh_w,
                                                       int8=reversed_dim,
                                                       allocat_tensor=reversed_dim)
        else:
            self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w, int8=reversed_dim)
        self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b,
                                                   self._4hh_b,
                                                   int8=reversed_dim,
                                                   allocat_tensor=reversed_dim)

    def release_qkv(self):
        del self.module.attention.attn_qkvw
        del self.module.attention.attn_qkvb
        self.module.attention.attn_qkvw = None
        self.module.attention.attn_qkvb = None

        qkv_data = [self.module.attention.attn_qw.data, \
                    self.module.attention.attn_qb.data, \
                    self.module.attention.attn_kw.data, \
                    self.module.attention.attn_kb.data, \
                    self.module.attention.attn_vw.data, \
                    self.module.attention.attn_vb.data]
        for data in qkv_data:
            del data

        self.module.attention.attn_qw = self.qw
        self.module.attention.attn_qb = self.qb
        self.module.attention.attn_kw = self.kw
        self.module.attention.attn_kb = self.kb
        self.module.attention.attn_vw = self.vw
        self.module.attention.attn_vb = self.vb

    def release_memory(self):
        self.release_qkv()
        del self.module.attention.attn_ow
        del self.module.attention.attn_ob
        self.module.attention.attn_ow = self.dense_w
        self.module.attention.attn_ob = self.dense_b
        del self.module.mlp.inter_w
        del self.module.mlp.inter_b
        del self.module.mlp.output_w
        del self.module.mlp.output_b
        self.module.mlp.inter_w = self._h4h_w
        self.module.mlp.inter_b = self._h4h_b
        self.module.mlp.output_w = self._4hh_w
        self.module.mlp.output_b = self._4hh_b

    def copy_data_to_new_module(self):
        if self.attn_nw is None:
@ -234,3 +367,106 @@ class BaseTransformerContainer(ABC):
            data = data.reshape(data.shape[-1], data.shape[-2])
        data.to(get_accelerator().current_device_name())
        return data

    def reset_qkv_experimental(self):
        if self.module.attention.attn_qkvw is None:
            self.module.attention.attn_qkvw = torch.empty(self.qw.shape[0] * 3,
                                                          self.qw.shape[0],
                                                          dtype=self.qw.dtype,
                                                          device=self.qw.device)
            self.module.attention.attn_qkvb = torch.empty(self.qw.shape[0] * 3,
                                                          dtype=self.qw.dtype,
                                                          device=self.qw.device)
        self.module.attention.attn_qkvw.data[:self.qw.shape[0]] = self.qw.data
        self.module.attention.attn_qkvb.data[:self.qw.shape[0]] = self.qb.data
        self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
        self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
        self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] = self.vw.data
        self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] = self.vb.data

        qkv_data = [self.qw.data, \
                    self.qb.data, \
                    self.kw.data, \
                    self.kb.data, \
                    self.vw.data, \
                    self.vb.data]

        self.qw.data = self.module.attention.attn_qkvw.data[:self.qw.shape[0]]
        self.qb.data = self.module.attention.attn_qkvb.data[:self.qw.shape[0]]
        self.kw.data = self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.kb.data = self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.vw.data = self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:]
        self.vb.data = self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:]

        for data in qkv_data:
            del data

    def reset_qkv(self):
        self.qkvw.data[:self.qw.shape[0]] = self.qw.data
        self.qkvb.data[:self.qw.shape[0]] = self.qb.data
        self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
        self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
        self.qkvw.data[2 * self.qw.shape[0]:] = self.vw.data
        self.qkvb.data[2 * self.qw.shape[0]:] = self.vb.data

        qkv_data = [self.qw.data, \
                    self.qb.data, \
                    self.kw.data, \
                    self.kb.data, \
                    self.vw.data, \
                    self.vb.data]

        self.qw.data = self.qkvw.data[:self.qw.shape[0]]
        self.qb.data = self.qkvb.data[:self.qw.shape[0]]
        self.kw.data = self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.kb.data = self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.vw.data = self.qkvw.data[2 * self.qw.shape[0]:]
        self.vb.data = self.qkvb.data[2 * self.qw.shape[0]:]

        for data in qkv_data:
            del data

    def set_params_wo_copy(self, Z3_enabled=False):
        self.module.mlp.attn_nw = self.attn_nw
        self.module.mlp.attn_nb = self.attn_nb
        self.module.norm_w = self.input_nw
        self.module.norm_b = self.input_nb
        self.module.mlp.inter_w = self._h4h_w
        self.module.mlp.inter_b = self._h4h_b
        self.module.mlp.output_w = self._4hh_w
        self.module.mlp.output_b = self._4hh_b
        self.module.attention.attn_ow = self.dense_w
        self.module.attention.attn_ob = self.dense_b
        if not Z3_enabled or self.q_k_v is None:
            self.module.attention.attn_qkvw = self.qkvw
            self.module.attention.attn_qkvb = self.qkvb
        if self.q_k_v is not None:
            if Z3_enabled:
                self.module.attention.attn_qw = self.qw
                self.module.attention.attn_qb = self.qb
                self.module.attention.attn_kw = self.kw
                self.module.attention.attn_kb = self.kb
                self.module.attention.attn_vw = self.vw
                self.module.attention.attn_vb = self.vb
            else:
                self.qw.data = self.qkvw[:self.qw.shape[0], :]
                self.qb.data = self.qkvb[:self.qw.shape[0]]
                self.kw.data = self.qkvw[self.qw.shape[0]:2 * self.qw.shape[0], :]
                self.kb.data = self.qkvb[self.qw.shape[0]:2 * self.qw.shape[0]]
                self.vw.data = self.qkvw[self.qw.shape[0] * 2:, :]
                self.vb.data = self.qkvb[self.qw.shape[0] * 2:]

    def get_lora_params(self):
        return self.lora_params

    def get_all_params(self):
        if self.q_k_v is not None:
            return [
                self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w,
                self._4hh_b, self.qw, self.qb, self.kw, self.kb, self.vw, self.vb, self.dense_w, self.dense_b
            ]
        else:
            return [
                self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w,
                self._4hh_b, self.qkvw, self.qkvb, self.dense_w, self.dense_b
            ]
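reset_qkv and set_params_wo_copy avoid extra memory by making the split q/k/v parameters views into the fused qkv storage through `.data` slicing. A standalone sketch of that aliasing trick, with illustrative names and sizes:

import torch

hidden = 4
qkvw = torch.zeros(3 * hidden, hidden)  # fused storage
qw = torch.nn.Parameter(torch.randn(hidden, hidden))

# copy the split weight into its slot, then re-point qw at that slot
qkvw.data[:hidden] = qw.data
qw.data = qkvw.data[:hidden]  # qw now aliases the fused buffer

qkvw.data[:hidden].fill_(1.0)
assert qw.data[0, 0].item() == 1.0  # updates are visible through both names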
@ -44,10 +44,18 @@ class HFBertLayerPolicy(TransformerPolicy):
            HFBertLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        if self.pre_attn_norm:
            attention_layernorm = self.client_module.PostAttentionLayerNorm
        else:
            attention_layernorm = self.client_module.attention.output.LayerNorm
        return self.client_module.attention.self.query.weight.shape[1], \
                self.client_module.attention.self.num_attention_heads
                self.client_module.attention.self.num_attention_heads, \
                attention_layernorm.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        qw = self.client_module.attention.self.query.weight
        qb = self.client_module.attention.self.query.bias
        kw = self.client_module.attention.self.key.weight
@ -55,8 +63,8 @@ class HFBertLayerPolicy(TransformerPolicy):
        vw = self.client_module.attention.self.value.weight
        vb = self.client_module.attention.self.value.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
@ -84,3 +92,6 @@ class HFBertLayerPolicy(TransformerPolicy):
               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias

    def get_lora_params(self):
        return []
@ -28,7 +28,7 @@ class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):
        self.module.config.scale_attention = self.scale_attention
        return self.module

    def attention_qkv_mp(self, mp_replace):
    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)

@ -84,9 +84,13 @@ class BLOOMLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.self_attention.hidden_size, \
                self.client_module.self_attention.num_heads
                self.client_module.self_attention.num_heads, \
                self.client_module.input_layernorm.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        return self.client_module.self_attention.query_key_value.weight, \
               self.client_module.self_attention.query_key_value.bias, \
               self.client_module.self_attention.dense.weight, \
@ -103,3 +107,6 @@ class BLOOMLayerPolicy(TransformerPolicy):
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias

    def get_lora_params(self):
        return []
@ -40,7 +40,11 @@ class HFCLIPLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.self_attn.q_proj.weight.shape[1], \
                self.client_module.self_attn.num_heads
                self.client_module.self_attn.num_heads, \
                self.client_module.layer_norm1.eps

    def get_q_k_v(self):
        return None

    def attention(self):
        qw = self.client_module.self_attn.q_proj.weight
@ -69,3 +73,6 @@ class HFCLIPLayerPolicy(TransformerPolicy):
               self.client_module.layer_norm2.bias, \
               self.client_module.layer_norm1.weight, \
               self.client_module.layer_norm1.bias

    def get_lora_params(self):
        return []
@ -45,9 +45,13 @@ class HFDistilBertLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.attention.q_lin.weight.shape[1], \
                self.client_module.attention.n_heads
                self.client_module.attention.n_heads, \
                self.client_module.sa_layer_norm.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        qw = self.client_module.attention.q_lin.weight
        qb = self.client_module.attention.q_lin.bias
        kw = self.client_module.attention.k_lin.weight
@ -55,8 +59,8 @@ class HFDistilBertLayerPolicy(TransformerPolicy):
        vw = self.client_module.attention.v_lin.weight
        vb = self.client_module.attention.v_lin.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0))
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
@ -77,3 +81,6 @@ class HFDistilBertLayerPolicy(TransformerPolicy):
               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias

    def get_lora_params(self):
        return []
@ -13,18 +13,18 @@ class MetaTensorContainer(ABC):
        self.is_meta = False
        self.ckpt_load_enabled = True

    def initialize_tensors(self):
        super().initialize_tensors()
    def initialize_tensors(self, enable_training=False):
        super().initialize_tensors(enable_training=enable_training)
        self.is_meta = self.qkvw.is_meta

    def apply_tensor_parallelism(self, mp_replace):
    def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None):
        if self.is_meta:
            if self.qkvb is None:
                self.module.attention.attn_qkvb = None
            if self.dense_b is None:
                self.module.attention.attn_ob = None
        else:
            super().apply_tensor_parallelism(mp_replace)
            super().apply_tensor_parallelism(mp_replace, mp_group, tp_size)

    def copy_data_to_new_module(self):
        if self.is_meta:
@ -37,9 +37,13 @@ class HFGPT2LayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.attn.embed_dim, \
                self.client_module.attn.num_heads
                self.client_module.attn.num_heads, \
                self.client_module.ln_1.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        return self.client_module.attn.c_attn.weight, \
               self.client_module.attn.c_attn.bias, \
               self.client_module.attn.c_proj.weight, \
@ -56,3 +60,6 @@ class HFGPT2LayerPolicy(TransformerPolicy):
               self.client_module.ln_2.bias, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias

    def get_lora_params(self):
        return []
@ -71,14 +71,18 @@ class HFGPTJLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.attn.q_proj.weight.shape[1], \
                self.client_module.attn.num_attention_heads
                self.client_module.attn.num_attention_heads, \
                self.client_module.ln_1.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        qw = self.client_module.attn.q_proj.weight
        kw = self.client_module.attn.k_proj.weight
        vw = self.client_module.attn.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)

        return qkvw, \
               None, \
@ -96,3 +100,6 @@ class HFGPTJLayerPolicy(TransformerPolicy):
               None, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias

    def get_lora_params(self):
        return []
@ -73,14 +73,18 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.attn.attention.q_proj.weight.shape[1], \
                self.client_module.attn.attention.num_heads
                self.client_module.attn.attention.num_heads, \
                self.client_module.ln_1.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        qw = self.client_module.attn.attention.q_proj.weight
        kw = self.client_module.attn.attention.k_proj.weight
        vw = self.client_module.attn.attention.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)

        return qkvw, \
               None, \
@ -98,3 +102,6 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):
               self.client_module.ln_2.bias, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias

    def get_lora_params(self):
        return []
@ -92,9 +92,13 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
            attention = self.client_module.self_attention

        return self.client_module.attention.query_key_value.weight.shape[1], \
                self.client_module.attention.num_attention_heads
                self.client_module.attention.num_attention_heads, \
                self.client_module.input_layernorm.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        if GPTNEOXLayerPolicy.version == 0:
            attention = self.client_module.attention
        else:
@ -116,3 +120,6 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias

    def get_lora_params(self):
        return []
@ -56,9 +56,13 @@ class MegatronLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.attention.query_key_value.weight.shape[1], \
                self.client_module.attention.num_attention_heads
                self.client_module.attention.num_attention_heads, \
                self.client_module.input_layernorm.eps

    def attention(self):
    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        if self.inference:
            if MegatronLayerPolicy.version == 0:
                attention = self.client_module.attention
@ -106,3 +110,6 @@ class MegatronLayerPolicy(TransformerPolicy):
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias

    def get_lora_params(self):
        return []
@ -12,6 +12,7 @@ from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora
from deepspeed.utils.types import ActivationFuncType


@ -77,20 +78,26 @@ class HFOPTLayerPolicy(TransformerPolicy):
                         pre_attn_norm=True,
                         use_load_prefix=use_load_prefix)
        self.client_module = client_module

        try:
            import transformers
            HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
            if isinstance(TransformerPolicy.hf_model_config, transformers.models.opt.configuration_opt.OPTConfig):
                self.pre_attn_norm = TransformerPolicy.hf_model_config.do_layer_norm_before
        except:
            HFOPTLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.self_attn.embed_dim, \
                self.client_module.self_attn.num_heads
                self.client_module.self_attn.num_heads, \
                self.client_module.self_attn_layer_norm.eps

    def attention(self):
    def get_q_k_v(self):
        return self.client_module.self_attn.q_proj.weight, \
               self.client_module.self_attn.q_proj.bias, \
               self.client_module.self_attn.k_proj.weight, \
               self.client_module.self_attn.k_proj.bias, \
               self.client_module.self_attn.v_proj.weight, \
               self.client_module.self_attn.v_proj.bias

    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias

@ -100,9 +107,8 @@ class HFOPTLayerPolicy(TransformerPolicy):
        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
@ -119,3 +125,16 @@ class HFOPTLayerPolicy(TransformerPolicy):
               self.client_module.final_layer_norm.bias, \
               self.client_module.self_attn_layer_norm.weight, \
               self.client_module.self_attn_layer_norm.bias

    def get_lora_params(self):
        all_lora_params = []
        for p in [
                self.client_module.fc1, \
                self.client_module.fc2, \
                self.client_module.self_attn.q_proj, \
                self.client_module.self_attn.k_proj, \
                self.client_module.self_attn.v_proj, \
                self.client_module.self_attn.out_proj, \
        ]:
            all_lora_params.append(maybe_get_lora(p))
        return all_lora_params
@ -55,22 +55,34 @@ class LinearLayer(nn.Module):

class Normalize(nn.Module):

    def __init__(self, dim, dtype=torch.float, eps=1e-5):
    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None):
        super(Normalize, self).__init__()
        self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
        self.weight = self.norm.weight
        self.bias = self.norm.bias
        if weight is not None:
            self.weight = weight
            self.bias = bias
        else:
            self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
            self.weight = self.norm.weight
            self.bias = self.norm.bias

        self.eps = eps

    def forward(self, input):
        return self.norm(input)
        return nn.functional.layer_norm(input, input.shape[-1:], self.weight, self.bias, eps=self.eps)


class EmbeddingLayer(nn.Module):

    def __init__(self, weight_shape, dtype=torch.half):
    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(EmbeddingLayer, self).__init__()
        self.weight = Parameter(
            torch.empty(weight_shape[0], weight_shape[1], dtype=dtype, device=get_accelerator().current_device_name()))
        if weight is None:
            self.weight = Parameter(
                torch.empty(weight_shape[0],
                            weight_shape[1],
                            dtype=dtype,
                            device=get_accelerator().current_device_name()))
        else:
            self.weight = weight

    def forward(self, input):
        return F.embedding(input, self.weight)
@ -81,11 +93,11 @@ class OPTEmbedding(EmbeddingLayer):
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, weight_shape):
    def __init__(self, weight_shape=None, weight=None, bias=None):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(weight_shape)
        super().__init__(weight_shape, weight=weight)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
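The reworked Normalize can wrap layer-norm weights that already live elsewhere (for example, parameters gathered from a trained model) instead of allocating its own nn.LayerNorm. A small standalone sketch of the same idea, with illustrative names:

import torch
import torch.nn.functional as F

w = torch.ones(8)   # pretend these came from an existing model
b = torch.zeros(8)
x = torch.randn(2, 8)

# functional layer_norm normalizes with borrowed weights,
# keeping no module state of our own
y = F.layer_norm(x, x.shape[-1:], w, b, eps=1e-5)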
@ -72,7 +72,7 @@ class TransformerPolicy(DSPolicy):
        self.split_qkv = split_qkv

    @abstractmethod
    def attention(self):
    def attention(self, enable_training=False):
        """
        Returns attention qkv and dense parameters
        weight: (3*hidden, hidden) and (hidden, hidden)
@ -80,6 +80,13 @@ class TransformerPolicy(DSPolicy):
        """
        raise NotImplementedError

    @abstractmethod
    def get_q_k_v(self):
        """
        return all q,k,v parameters without merging them together
        """
        raise NotImplementedError

    @abstractmethod
    def get_hidden_heads(self):
        """
@ -105,6 +112,14 @@ class TransformerPolicy(DSPolicy):
        """
        raise NotImplementedError

    @abstractmethod
    def get_lora_params(self):
        """
        Returns lora parameters used in transformer layer
        """
        raise NotImplementedError


# TODO (lekurile): This function exists in base container as well, consolidate at some point
def transpose(data):
@ -189,3 +204,19 @@ def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names
        dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \
                                    transpose(qkv_data)), int8=weight_quantizer.q_int8)
        setattr(module, dst_name, dst)


def pack_lora_weights(p):
    return [
        p.lora_right_weight, \
        p.lora_left_weight, \
        p.lora_scaling
    ]


def maybe_get_lora(p):
    if hasattr(p, 'lora_right_weight'):
        lora_param = pack_lora_weights(p)
    else:
        lora_param = []
    return lora_param
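maybe_get_lora lets policies collect LoRA weights opportunistically: any layer that carries the lora_right_weight attribute contributes its [right, left, scaling] triplet, everything else contributes an empty list. A sketch of the behavior (the LoRA attributes are attached by hand here purely for illustration):

import torch

lin = torch.nn.Linear(4, 4)
lin.lora_right_weight = torch.zeros(4, 2)  # pretend a LoRA wrapper added these
lin.lora_left_weight = torch.zeros(2, 4)
lin.lora_scaling = 0.5

plain = torch.nn.Linear(4, 4)  # no LoRA attributes

assert len(maybe_get_lora(lin)) == 3  # [right, left, scaling]
assert maybe_get_lora(plain) == []    # silently skipped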
@ -87,10 +87,12 @@ class ReplaceWithTensorSlicing:
            dst.scale = src.scale
        return dst

    def copy(self, dst, src, int8=False):
    def copy(self, dst, src, int8=False, allocat_tensor=False):
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocat_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
@ -102,21 +104,21 @@ class ReplaceWithTensorSlicing:
            else:
                if src_shape[inner_dim] != dst_shape[self.in_dim]:
                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                    weight_split = torch.split(src, dst_shape[self.in_dim], dim=inner_dim)[self.gpu_index].contiguous()
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim], :])
                else:
                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                    weight_split = torch.split(src.data, dst_shape[self.out_dim],
                                               dim=outer_dim)[self.gpu_index].contiguous()
                    dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(weight_split.shape)
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim], :])
        else:
            if src_shape[0] == dst_shape[0]:
                dst.data.copy_(src)
                dst = src
            else:
                bias_split = torch.split(src.data, dst_shape[-1])[self.gpu_index].contiguous()
                dst.data.copy_(bias_split)
                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])
        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale

        return dst
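The rewritten copy() slices the source tensor directly by rank offset instead of materializing every shard with torch.split. A standalone sketch of the row slicing it performs; gpu_index and mp_size are illustrative values:

import torch

mp_size, gpu_index = 2, 1  # rank 1 of 2 model-parallel ranks
src = torch.arange(16.).reshape(8, 2)  # full weight
out_rows = src.shape[0] // mp_size     # rows owned by each rank

# the rank's shard is a contiguous row block; no intermediate split list
shard = src[gpu_index * out_rows:(gpu_index + 1) * out_rows, :]
assert shard.shape == (4, 2)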
@ -69,10 +69,13 @@ class DeepSpeedInferenceConfig(TransformerConfig):
                 training_mp_size=1,
                 bigscience_bloom=False,
                 max_out_tokens=1024,
                 min_out_tokens=1,
                 enable_qkv_quantization=False,
                 use_mup=False,
                 scale_attn_by_inverse_layer_idx=False,
                 return_single_tuple=False):
                 return_single_tuple=False,
                 set_empty_params=False,
                 transposed_mode=False):
        super(DeepSpeedInferenceConfig,
              self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
                             num_hidden_layers)
@ -97,10 +100,13 @@ class DeepSpeedInferenceConfig(TransformerConfig):
        self.training_mp_size = training_mp_size
        self.bigscience_bloom = bigscience_bloom
        self.max_out_tokens = max_out_tokens
        self.min_out_tokens = min_out_tokens
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.enable_qkv_quantization = enable_qkv_quantization
        self.use_mup = use_mup
        self.return_single_tuple = return_single_tuple
        self.set_empty_params = set_empty_params
        self.transposed_mode = transposed_mode

    @classmethod
    def from_dict(cls, json_object):
@ -15,6 +15,7 @@ minus_inf = -10000.0

class DeepSpeedSelfAttention(nn.Module):
    num_layers = 0
    _qkv_buffers = []

    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1):
        super(DeepSpeedSelfAttention, self).__init__()
@ -24,23 +25,35 @@ class DeepSpeedSelfAttention(nn.Module):
        self.config.layer_id = DeepSpeedSelfAttention.num_layers
        DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'
        qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
        self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                  qkv_size_per_partition,
                                                  dtype=data_type,
                                                  device=device),
                                      requires_grad=False)
        self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
                                      requires_grad=False)
        out_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                self.config.hidden_size,
                                                dtype=data_type,
                                                device=device),
                                    requires_grad=False)
        if self.config.set_empty_params:
            self.attn_qw = None
            self.attn_qb = None
            self.attn_kw = None
            self.attn_kb = None
            self.attn_vw = None
            self.attn_vb = None
            self.attn_qkvw = None
            self.attn_qkvb = None
            self.attn_ow = None
            self.attn_ob = None
        else:
            qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
            self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                      qkv_size_per_partition,
                                                      dtype=data_type,
                                                      device=device),
                                          requires_grad=False)
            self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
                                          requires_grad=False)
            out_size_per_partition = self.config.hidden_size // self.config.mp_size
            self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                    self.config.hidden_size,
                                                    dtype=data_type,
                                                    device=device),
                                        requires_grad=False)

        self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)
            self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)

        self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
@ -65,6 +78,14 @@ class DeepSpeedSelfAttention(nn.Module):
        self.score_context_func = SoftmaxContextOp(config)
        self.linear_func = LinearOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)
        if len(DeepSpeedSelfAttention._qkv_buffers) == 0:
            DeepSpeedSelfAttention._qkv_buffers = [
                torch.empty(self.hidden_size_per_partition * 3,
                            self.config.hidden_size,
                            dtype=data_type_fp,
                            device=device),
                torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device)
            ]

    def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
        if isinstance(qkv_out, list):
@ -89,6 +110,18 @@ class DeepSpeedSelfAttention(nn.Module):
            context_layer, key_layer, value_layer = attn_key_value
        return context_layer, key_layer, value_layer

    def _merge_qkv(self):
        qvkw = DeepSpeedSelfAttention._qkv_buffers[0]
        qvkw[:self.hidden_size_per_partition, :] = self.attn_qw
        qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = self.attn_kw
        qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw
        if self.attn_qb is not None:
            qvkb = DeepSpeedSelfAttention._qkv_buffers[1]
            qvkb[:self.hidden_size_per_partition] = self.attn_qb
            qvkb[self.hidden_size_per_partition:2 * self.hidden_size_per_partition] = self.attn_kb
            qvkb[2 * self.hidden_size_per_partition:] = self.attn_vb
        return DeepSpeedSelfAttention._qkv_buffers
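_merge_qkv stacks the split q/k/v weights into one shared class-level buffer so a single GEMM can produce q, k and v in one pass. A standalone sketch of that layout, with illustrative sizes:

import torch

h = 4  # hidden size per partition
qw, kw, vw = (torch.full((h, h), float(i)) for i in range(3))

fused = torch.empty(3 * h, h)  # one buffer, three stacked blocks
fused[:h] = qw
fused[h:2 * h] = kw
fused[2 * h:] = vw

x = torch.randn(2, h)
qkv = x @ fused.t()             # single GEMM -> shape (2, 3h)
q, k, v = qkv.split(h, dim=-1)  # views into the fused result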
    def forward(self,
                input,
                input_mask,
@ -101,30 +134,33 @@ class DeepSpeedSelfAttention(nn.Module):
                norm_w=None,
                norm_b=None,
                alibi=None):
        if self.attn_qkvw is None:
            self._attn_qkvw, self._attn_qkvb = self._merge_qkv()
        else:
            self._attn_qkvw = self.attn_qkvw
            self._attn_qkvb = self.attn_qkvb

        if not self.config.pre_layer_norm:
            qkv_out = self.linear_func(input=input,
                                       weight=self.attn_qkvw,
                                       bias=self.attn_qkvb,
                                       weight=self._attn_qkvw,
                                       bias=self._attn_qkvb,
                                       add_bias=self.attn_qkvb is not None,
                                       do_flash_attn=False,
                                       num_heads=self.num_attention_heads_per_partition,
                                       num_layers=DeepSpeedSelfAttention.num_layers)
        else:
            qkv_out = self.qkv_func(input=input,
                                    weight=self.attn_qkvw,
                                    bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
                                    weight=self._attn_qkvw,
                                    bias=(self._attn_qkvb if self._attn_qkvb is not None else norm_b),
                                    gamma=norm_w,
                                    beta=norm_b,
                                    add_bias=(self.attn_qkvb is not None),
                                    num_layers=DeepSpeedSelfAttention.num_layers,
                                    num_heads=self.num_attention_heads_per_partition)

        context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
                                                                       input_mask=input_mask,
                                                                       layer_past=layer_past,
                                                                       alibi=alibi)

        output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)

        inp_norm = qkv_out[-1]
@ -20,25 +20,33 @@ class DeepSpeedMLP(nn.Module):
        data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
        data_type_fp = torch.half if config.fp16 else torch.float
        device = get_accelerator().current_device_name()
        self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)
        self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)
        intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
        self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
                                                intm_size_per_partition,
                                                dtype=data_type,
                                                device=device),
                                    requires_grad=False)
        self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
                                    requires_grad=False)
        self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
                                                 self.config.hidden_size,
                                                 dtype=data_type,
                                                 device=device),
                                     requires_grad=False)
        self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                     requires_grad=False)
        if self.config.set_empty_params:
            self.attn_nw = None
            self.attn_nb = None
            self.inter_w = None
            self.inter_b = None
            self.output_w = None
            self.output_b = None
        else:
            self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)
            self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)
            intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
            self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
                                                    intm_size_per_partition,
                                                    dtype=data_type,
                                                    device=device),
                                        requires_grad=False)
            self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
                                        requires_grad=False)
            self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
                                                     self.config.hidden_size,
                                                     dtype=data_type,
                                                     device=device),
                                         requires_grad=False)
            self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                         requires_grad=False)

        # used for quantization
        self.q_scales = q_scales
@ -74,8 +82,6 @@ class DeepSpeedMLP(nn.Module):
                                                  final_bias=self.output_b,
                                                  add_bias=bias is not None,
                                                  residual_add=residual_add)

        if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(residual, group=self.mp_group)

        return residual
@ -24,5 +24,6 @@ class GELUGemmOp(BaseOp):
                weight_out: torch.Tensor,
                async_op: bool = False):
        output = self.fused_gemm_gelu(input, weight, weight.scale, bias, weight_out, weight_out.scale,
                                      self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8, async_op)
                                      self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8, async_op,
                                      self.config.transposed_mode)
        return output
@ -26,5 +26,6 @@ class LinearOp(BaseOp):
                num_heads: int,
                external_cache: bool = None,
                num_layers: int = None):
        qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads)
        qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
                                   self.config.transposed_mode)
        return qkv_out
@ -20,8 +20,10 @@ class MLPGemmOp(BaseOp):
    def forward(self, input: torch.Tensor, residual: torch.Tensor, input_bias: torch.Tensor,
                weight_interm: torch.Tensor, weight_out: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor,
                beta: torch.Tensor):
        output, residual_add = self.mlp_gemm_func(input, residual, input_bias, weight_interm, weight_out, bias, gamma,
                                                  beta, self.config.epsilon, self.config.pre_layer_norm,
                                                  self.config.mlp_after_attn, weight_interm.scale, weight_out.scale,
                                                  self.config.q_int8, self.config.mlp_act_func_type)
        output, residual_add = self.mlp_gemm_func(
            input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, self.config.epsilon,
            self.config.pre_layer_norm, self.config.mlp_after_attn,
            weight_interm.scale if hasattr(weight_interm, 'scale') else torch.empty(1),
            weight_out.scale if hasattr(weight_out, 'scale') else torch.empty(1), self.config.q_int8,
            self.config.mlp_act_func_type, self.config.transposed_mode)
        return output, residual_add
@ -28,10 +28,11 @@ class QKVGemmOp(BaseOp):
                num_layers: int,
                num_heads: int = None,
                max_out_tokens: int = None):
        q_scale = weight.scale
        q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1)
        external_cache = self.config.bigscience_bloom
        rank = dist.get_rank() if dist.is_initialized() else 0
        q_int8 = self.config.q_int8
        output = self.qkv_gemm_func(input, weight, q_scale, bias, gamma, beta, self.config.epsilon, add_bias,
                                    num_layers, external_cache, self.config.mp_size, rank, q_int8)
                                    num_layers, external_cache, self.config.mp_size, rank, q_int8,
                                    self.config.transposed_mode)
        return output
@ -18,7 +18,7 @@ class VectorMatMulOp(BaseOp):
            self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp32

    def forward(self, input: torch.Tensor, weight: torch.Tensor, async_op: bool = False):
        q_scale = weight.scale
        q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1)
        q_int8 = self.config.q_int8
        output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8)
        output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8, self.config.transposed_mode)
        return output
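Several of these ops now guard weight.scale with hasattr so that unquantized weights (plain tensors without a scale attribute) no longer raise. The pattern in isolation, as a hedged sketch:

import torch

def get_q_scale(weight: torch.Tensor) -> torch.Tensor:
    # quantized weights carry a `scale` attribute; fp16/fp32 weights do not,
    # so fall back to a placeholder the kernel ignores when q_int8 is False
    return weight.scale if hasattr(weight, 'scale') else torch.empty(1)

w = torch.randn(4, 4)
assert get_q_scale(w).numel() == 1  # placeholder path for an unquantized weight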
@ -31,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel

from ..git_version_info import version as __version__
from ..utils import logger
@ -514,6 +515,21 @@ def get_memory_breakdown(param_dict):
    return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)


class HybridEngineConfig(DeepSpeedConfigModel):
    enabled: bool = False
    max_out_tokens: int = 512
    inference_tp_size: int = 1
    release_inference_cache: bool = False
    pin_parameters: bool = True
    tp_gather_partition_size: int = 8


def get_hybrid_engine_config(param_dict):
    hybrid_engine_config_dict = param_dict.get("hybrid_engine", {})
    hybrid_engine_config = HybridEngineConfig(**hybrid_engine_config_dict)
    return hybrid_engine_config

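The new config block maps one-to-one onto the "hybrid_engine" section of the user's ds_config, with unset fields keeping the defaults above. A small usage sketch of the parser just defined (values are examples, not recommendations):

cfg = get_hybrid_engine_config({
    "hybrid_engine": {
        "enabled": True,
        "inference_tp_size": 2,
        "release_inference_cache": False,
        "tp_gather_partition_size": 8,
    }
})
assert cfg.enabled and cfg.max_out_tokens == 512  # unset fields keep defaults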
def get_eigenvalue_config(param_dict):
    if get_quantize_enabled(param_dict):
        param_dict = param_dict[QUANTIZE_TRAINING]
@ -816,6 +832,8 @@ class DeepSpeedConfig(object):
            self.eigenvalue_layer_num,
        ) = get_eigenvalue_config(param_dict)

        self.hybrid_engine = get_hybrid_engine_config(param_dict)

        self.sparse_attention = get_sparse_attention(param_dict)
        self.pipeline = get_pipeline_config(param_dict)
@ -33,7 +33,7 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
|
||||
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
|
||||
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
|
||||
|
||||
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
|
||||
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
|
||||
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
|
||||
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER
|
||||
|
||||
@ -195,7 +195,7 @@ class DeepSpeedEngine(Module):
|
||||
dist_init_required=None,
|
||||
collate_fn=None,
|
||||
config=None,
|
||||
config_params=None,
|
||||
config_class=None,
|
||||
dont_change_device=False,
|
||||
):
|
||||
super(DeepSpeedEngine, self).__init__()
|
||||
@ -213,6 +213,7 @@ class DeepSpeedEngine(Module):
|
||||
self.gradient_average = True
|
||||
self.warn_unscaled_loss = True
|
||||
self.config = config
|
||||
self._config = config_class
|
||||
self.loaded_checkpoint_mp_world_size = None
|
||||
self.loaded_checkpoint_dp_world_size = None
|
||||
self.enable_backward_allreduce = True
|
||||
@ -242,10 +243,6 @@ class DeepSpeedEngine(Module):
|
||||
# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
|
||||
self.param_names = {param: name for name, param in model.named_parameters()}
|
||||
|
||||
# Set config using config_params for backwards compat
|
||||
if self.config is None and config_params is not None:
|
||||
self.config = config_params
|
||||
|
||||
from deepspeed.comm import supported_torch_version
|
||||
# This supported_torch_version check is for torch1.2 compatibility only
|
||||
if supported_torch_version:
|
||||
@ -949,19 +946,8 @@ class DeepSpeedEngine(Module):
        if hasattr(args, 'local_rank'):
            args.local_rank = self.local_rank

        if self.config is None:
            self.config = (args.deepspeed_config if hasattr(args, "deepspeed_config") else None)
        self._config = DeepSpeedConfig(self.config, mpu)

    # Validate command line arguments
    def _do_args_sanity_check(self, args):
        if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
            logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
            if hasattr(args, "deepspeed_config"):
                assert (args.deepspeed_config is
                        None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
            args.deepspeed_config = args.deepscale_config

        assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \
            "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \
            "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
@ -975,10 +961,6 @@ class DeepSpeedEngine(Module):
            env_local_rank == args.local_rank
        ), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."

        if self.config is None:
            assert (hasattr(args, "deepspeed_config") and args.deepspeed_config
                    is not None), "DeepSpeed requires --deepspeed_config to specify configuration file"

    def _is_supported_optimizer(self, optimizer_name):
        return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None)
@ -60,6 +60,7 @@ class LossScalerBase:
    def backward(self, loss, retain_graph=False):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward(retain_graph=retain_graph)
        # print(f'LossScalerBackward: {scaled_loss=}')


class LossScaler(LossScalerBase):
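The scaling in LossScalerBase.backward works because gradients are linear in the loss: backpropagating loss * scale yields scale * dL/dw, which the FP16 optimizer divides back out before stepping. A minimal standalone sketch of that identity (not DeepSpeed code):

import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w * w).sum()
loss_scale = 1024.0

(loss * loss_scale).backward()        # scaled backward, as in LossScalerBase.backward
unscaled_grad = w.grad / loss_scale   # the optimizer must unscale before stepping
# unscaled_grad == 2 * w, the true gradient of (w * w).sum()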
deepspeed/runtime/hybrid_engine.py (new file, 404 lines)
@ -0,0 +1,404 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.inference.config import DeepSpeedInferenceConfig
from deepspeed.module_inject.replace_policy import replace_policies
from deepspeed.module_inject.utils import policy_to_ds_container
from .engine import DeepSpeedEngine
from .utils import TLinear, get_inactive_params
from deepspeed.runtime.zero import GatheredParameters
import time
import gc

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from torch import nn
from deepspeed.utils import logger

from deepspeed.ops.op_builder import InferenceBuilder

from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
try:
    import transformers
    OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except:
    OPTLearnedPositionalEmbedding = None
inference_cuda_module = None


class DeepSpeedHybridEngine(DeepSpeedEngine):
    r"""DeepSpeed engine for training and inference."""
    inference_mp_group = None

    def __init__(self, args, model, **kwargs):

        super().__init__(args, model, **kwargs)

        # synch seed between all GPUs
        _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
        dist.broadcast(_rng_state, 0)
        get_accelerator().set_rng_state(_rng_state.cpu())

        self.Z3_enabled = (self._config.zero_config.stage == 3)
        self.gather_all_layers = self._config.hybrid_engine.pin_parameters

        # inference containers / fwds
        self._inference_containers = []
        self._orig_modules = []
        self._orig_fwds = []
        self.create_inference_module()

        # Performance stats
        self._t_start = None
        self._total_latency = 0
        self._iters = 0
        self._training_start_time = None
        self._generate_latency = 0
        self._training_latency = 0
        self._total_batch_size = None
        self._gather_latency = 0

        global inference_cuda_module
        if inference_cuda_module is None:
            builder = InferenceBuilder()
            inference_cuda_module = builder.load()

        self.is_lora_fused = False
    def convert_to_linear_transposed(self, model):

        def _replace_linear_layer(r_module, parent_type=None, prev_type=None):
            for name, child in r_module.named_children():
                if child.__class__ in [torch.nn.Linear] and \
                    (parent_type is torch.nn.ModuleList or prev_type is torch.nn.ModuleList):
                    setattr(r_module, name, TLinear(child, name))
                else:
                    _replace_linear_layer(child, type(r_module), prev_type=parent_type)
            return r_module

        _replace_linear_layer(model)
    def new_inference_container(self, orig_layer, policy_cls, layer_id):
        policy = policy_cls(orig_layer, inference=True)
        _container = policy_to_ds_container(
            policy=policy,
            config=DeepSpeedInferenceConfig(set_empty_params=True,
                                            max_out_tokens=self._config.hybrid_engine.max_out_tokens,
                                            min_out_tokens=self._config.hybrid_engine.max_out_tokens,
                                            transposed_mode=True),
            model_config=self.module.config if hasattr(self.module, 'config') else None,
            layer_id=layer_id,
            child=orig_layer)
        _container.set_dtype(self._config.fp16_enabled)

        _container.set_tensor_parallel_config(self._config.hybrid_engine.inference_tp_size, self.mp_group)
        _container.initialize_tensors(enable_training=True)
        _container.create_ds_model_config()
        _container.create_module()
        _container.set_params_wo_copy(Z3_enabled=self.Z3_enabled)
        return _container

    def populate_all_inference_policies(self):
        self.inference_policies = {}
        for plcy in replace_policies:
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    self.inference_policies.update({orig_layer_class: (self.new_inference_container, plcy)})
            elif plcy._orig_layer_class is not None:
                self.inference_policies.update({plcy._orig_layer_class: (self.new_inference_container, plcy)})
        self.inference_policies.update({
            nn.Linear: (LinearLayer, ),
            nn.Embedding: (EmbeddingLayer, ),
            nn.LayerNorm: (Normalize, ),
            OPTLearnedPositionalEmbedding: (OPTEmbedding, )
        })
    def _fuse_lora(self, params, lora_params):
        maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
        for lora_param, weight in zip(lora_params, maybe_has_lora_params):
            if len(lora_params) > 0:
                lora_right_weight, \
                lora_left_weight, \
                lora_scaling = lora_param
                weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())

    def fuse_lora_weight(self):
        for layer_id in range(len(self.layer_params)):
            self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])

    def _unfuse_lora(self, params, lora_params):
        maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
        for lora_param, weight in zip(lora_params, maybe_has_lora_params):
            if len(lora_params) > 0:
                lora_right_weight, \
                lora_left_weight, \
                lora_scaling = lora_param
                weight.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())

    def unfuse_lora_weight(self):
        for layer_id in range(len(self.layer_params)):
            self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])

    def unfuse_lora_weight_non_pinned(self):
        for layer_id in range(len(self.layer_params)):
            non_active_params = get_inactive_params(self.layer_params[layer_id])
            non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
            non_active_params.extend(non_active_lora_params)

            with GatheredParameters(non_active_params):
                self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
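The fuse/unfuse pair folds each LoRA delta into the base weight in place before generation and subtracts it again afterwards, so the inference kernels see a single dense weight. The underlying identity is W' = W + s * (left @ right); a standalone sketch with hypothetical shapes (the engine applies the same fold modulo its transposed weight layout):

import torch

d_out, d_in, r = 8, 6, 2
weight = torch.randn(d_out, d_in)
lora_right = torch.randn(r, d_in)    # "right" factor: input -> rank r
lora_left = torch.randn(d_out, r)    # "left" factor: rank r -> output
scaling = 0.5

x = torch.randn(4, d_in)
base_plus_lora = x @ weight.t() + scaling * (x @ lora_right.t()) @ lora_left.t()

fused = weight + scaling * lora_left @ lora_right   # one dense matmul after fusion
assert torch.allclose(x @ fused.t(), base_plus_lora, atol=1e-5)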
    def retake_inference_cache(self):
        if self._config.hybrid_engine.release_inference_cache:
            retake_success = inference_cuda_module.retake_workspace()

            if not retake_success:
                logger.warning("Unable to acquire workspace on first attempt, emptying cache and retrying.")
                gc.collect()
                get_accelerator().empty_cache()
                retake_success = inference_cuda_module.retake_workspace()

                if not retake_success:
                    raise RuntimeError("Unable to retake inference workspace.")
    def generate(self, *inputs, **kwargs):
        if self._total_batch_size is None:
            bsz = inputs[0].shape[0] if len(inputs) > 0 else \
                kwargs['input_ids'].shape[0]
            self._total_batch_size = bsz * dist.get_world_size()

        self._t0 = time.time()

        if self.Z3_enabled and self.gather_all_layers:
            if self._config.hybrid_engine.inference_tp_size > 1:
                non_tp_params = []
                for other_layer in self._other_layers:
                    non_tp_params.extend(list(other_layer.parameters()))

                partition_size = self._config.hybrid_engine.tp_gather_partition_size

                layer_groups = len(self.layer_params) // partition_size
                for lg in range(layer_groups):
                    non_active_params = []
                    non_active_lora_params = []
                    for layer_id in range(lg * partition_size, min(len(self.layer_params), (lg + 1) * partition_size),
                                          1):
                        non_tp_params.extend(self.layer_params[layer_id][:4])
                        non_active_params.extend(get_inactive_params(self.layer_params[layer_id]))
                        non_active_params.extend(get_inactive_params(self.layer_lora_params[layer_id]))
                    with GatheredParameters(non_active_params):
                        for layer_id in range(lg * partition_size,
                                              min(len(self.layer_params), (lg + 1) * partition_size), 1):
                            if len(self.all_lora_params) > 0:
                                self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
                            self._inference_containers[layer_id].apply_tensor_parallelism(
                                mp_group=self.mp_group, tp_size=self._config.hybrid_engine.inference_tp_size)

                # TODO(cmikeh2) Evaluate if this can be deferred when release_inference_cache
                # is enabled.
                gc.collect()
                get_accelerator().empty_cache()

                self._gather_latency = time.time() - self._t0

                input_shape = inputs[0].shape if len(inputs) > 0 else \
                    kwargs['input_ids'].shape
                output = torch.zeros(
                    (input_shape[0] * self._config.hybrid_engine.inference_tp_size, ) + input_shape[1:],
                    dtype=inputs[0].dtype if len(inputs) > 0 else kwargs['input_ids'].dtype,
                    device=inputs[0].device if len(inputs) > 0 else kwargs['input_ids'].device)
                input_cont = inputs[0].contiguous() if len(inputs) > 0 else kwargs['input_ids'].contiguous()
                dist.all_gather_into_tensor(output, input_cont, group=self.mp_group)

                if len(inputs) > 0:
                    inputs = (output, )
                else:
                    kwargs['input_ids'] = output

                self.retake_inference_cache()

                non_active_params = get_inactive_params(non_tp_params)
                with GatheredParameters(non_active_params):
                    generate_ret_vals = self._generate(*inputs, **kwargs)

                for layer_id in range(len(self.layer_params)):
                    self._inference_containers[layer_id].release_memory()

                rank = dist.get_rank(group=self.mp_group)
                generate_ret_vals = generate_ret_vals[input_shape[0] * rank:input_shape[0] * (rank + 1)]

            else:
                non_active_layers = get_inactive_params(self.all_layers_params)
                non_active_lora_params = get_inactive_params(self.all_lora_params)
                non_active_layers.extend(non_active_lora_params)
                with GatheredParameters(non_active_layers):
                    self._gather_latency = time.time() - self._t0

                    if len(self.all_lora_params) > 0:
                        self.fuse_lora_weight()

                    self.retake_inference_cache()
                    generate_ret_vals = self._generate(*inputs, **kwargs)

                    if len(self.all_lora_params) > 0:
                        self.unfuse_lora_weight()
        else:
            if len(self.all_lora_params) > 0 and (not self.Z3_enabled):
                self.fuse_lora_weight()

            self.retake_inference_cache()
            generate_ret_vals = self._generate(*inputs, **kwargs)

            if len(self.all_lora_params) > 0:
                if (not self.Z3_enabled):
                    self.unfuse_lora_weight()
                else:
                    self.unfuse_lora_weight_non_pinned()
                self.is_lora_fused = False

        if self._config.hybrid_engine.release_inference_cache:
            inference_cuda_module.release_workspace()
            gc.collect()
            get_accelerator().empty_cache()

        self._generate_latency = time.time() - self._t0 - self._gather_latency

        return generate_ret_vals
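One detail worth calling out in the tensor-parallel branch of generate(): every rank contributes its micro-batch via all_gather_into_tensor, all ranks run generation on the combined batch, and each rank then slices back only its own rows. A toy sketch of that gather-then-slice pattern (hypothetical sizes, no model or process group involved):

import torch

tp_size, bsz, seq = 4, 2, 16
# pretend each of the tp_size ranks holds a (bsz, seq) input; concatenating them
# mimics what all_gather_into_tensor produces on every rank
per_rank = [torch.full((bsz, seq), fill_value=r) for r in range(tp_size)]
gathered = torch.cat(per_rank, dim=0)           # the (tp_size * bsz, seq) combined batch

rank = 2                                         # stands in for dist.get_rank(group=mp_group)
mine = gathered[bsz * rank:bsz * (rank + 1)]     # the slice generate() keeps
assert torch.equal(mine, per_rank[rank])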
    def create_inference_containers(self, module, layer_id=0):
        for name, child in module.named_children():
            if child.__class__ in self.inference_policies:
                if self.inference_policies[child.__class__][0] == self.new_inference_container:
                    self._inference_containers.append(self.inference_policies[child.__class__][0](
                        child, self.inference_policies[child.__class__][-1], layer_id))
                    self._orig_modules.append(child)
                    self._orig_fwds.append(child.forward)

                    self.layer_params.append(self._inference_containers[layer_id].get_all_params())

                    self.lora_params.append(self._inference_containers[layer_id].get_lora_params())
                    self.layer_lora_params.append([])
                    for lora_param in self.lora_params[layer_id]:
                        self.layer_lora_params[layer_id].extend(lora_param[:-1])
                        self.all_lora_params.extend(lora_param[:-1])

                    layer_id += 1
                else:
                    self._other_layers.append(self.inference_policies[child.__class__][0](
                        weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
                    self._orig_modules_others.append(child)
                    self._orig_fwds_others.append(child.forward)
            else:
                self.create_inference_containers(child, layer_id=layer_id)
    def create_inference_module(self):
        self.layer_params = []
        self.layer_lora_params = []
        self.lora_params = []
        self.all_lora_params = []

        self._other_layers = []
        self._orig_modules_others = []
        self._orig_fwds_others = []

        if self._config.hybrid_engine.inference_tp_size > 1:
            global_rank = dist.get_rank()
            world_size = dist.get_world_size()
            mp_group_id = global_rank // self._config.hybrid_engine.inference_tp_size
            num_mp_groups = world_size // self._config.hybrid_engine.inference_tp_size
            for mp_group_id in range(num_mp_groups):
                ranks = list(
                    range(mp_group_id * self._config.hybrid_engine.inference_tp_size, \
                          (mp_group_id + 1) * self._config.hybrid_engine.inference_tp_size, \
                          1)
                )
                mp_group = dist.new_group(ranks)
                if global_rank in ranks:
                    self.mp_group = mp_group
        else:
            self.mp_group = None
        self.populate_all_inference_policies()
        self.all_layers_params = list(self.module.parameters())
        self.create_inference_containers(self.module)

        self._generate = self.module.generate
        self.module.generate = self.generate

        self._t0 = time.time()
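The mp_group construction above carves the world into contiguous blocks of inference_tp_size ranks; every rank iterates over all dist.new_group calls (group creation is collective), but keeps only the group containing itself. A quick sketch of the resulting grouping (assumed sizes, pure Python):

def tp_groups(world_size, tp_size):
    """Contiguous rank blocks, as in create_inference_module."""
    return [list(range(g * tp_size, (g + 1) * tp_size)) for g in range(world_size // tp_size)]

print(tp_groups(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]] -> rank 5 keeps the second group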
    def _zero3_forward(self, layer_id):

        def run_forward(*inputs, **kwargs):
            non_active_params = get_inactive_params(self.layer_params[layer_id])
            non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
            non_active_params.extend(non_active_lora_params)

            with GatheredParameters(non_active_params):
                if len(self.all_lora_params) > 0:
                    # Use the is_lora_fused flag to prevent multiple fusion in Z3 with non-pinned memory
                    if not self.is_lora_fused:
                        self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
                    # Set the is_lora_fused to true when reaching the last layer
                    if layer_id == len(self.layer_params) - 1:
                        self.is_lora_fused = True
                return self._inference_containers[layer_id].module.forward(*inputs, **kwargs)

        return run_forward
    def eval(self):
        if self._t_start is not None:
            latency = time.time() - self._t_start
            self._total_latency = self._total_latency + latency
            self._iters = self._iters + 1
            if not dist.is_initialized() or dist.get_rank() == 0:
                others = latency - (self._generate_latency + self._training_latency)
                print(f'|E2E latency={(latency):.2f}s ' + \
                      f'|Gather latency={self._gather_latency:.2f}s ({(self._gather_latency / latency * 100):.2f}%) '
                      f'|Generate time={(self._generate_latency):.2f}s ({(self._generate_latency / latency * 100):.2f}%) ' + \
                      f'|Training time={(self._training_latency):.2f}s ({(self._training_latency / latency * 100):.2f}%) ' + \
                      f'|Others={others:.2f} ({(others / latency * 100):.2f}%)'
                      f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + \
                      f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}')
            self._t_start = time.time()
        self._training_latency = 0
        super().eval()
        if len(self._inference_containers) > 0:
            for i, (orig_module, inference_container) in enumerate(zip(self._orig_modules,
                                                                       self._inference_containers)):
                if self.Z3_enabled and not self.gather_all_layers:
                    orig_module.forward = self._zero3_forward(i)
                else:
                    orig_module.forward = inference_container.module.forward

            if not self.Z3_enabled or self.gather_all_layers:
                for orig_module, inference_layer in zip(self._orig_modules_others, self._other_layers):
                    orig_module.forward = inference_layer.forward
        if self.Z3_enabled:
            gc.collect()
            get_accelerator().empty_cache()
        if self._t_start is None:
            self._t_start = time.time()
    def train(self, mode=True):
        if mode and len(self._orig_modules) > 0:
            for orig_module, orig_fwd in zip(self._orig_modules, self._orig_fwds):
                orig_module.forward = orig_fwd
            for orig_module, orig_fwd in zip(self._orig_modules_others, self._orig_fwds_others):
                orig_module.forward = orig_fwd
        super().train(mode)
        if mode:
            self._training_start_time = time.time()

    def step(self, lr_kwargs=None):
        super().step(lr_kwargs=lr_kwargs)

        if(self._inference_containers[0].module.attention.attn_qkvw is not None and \
            self._inference_containers[0].q_k_v is not None):
            for inference_container in self._inference_containers:
                inference_container.reset_qkv()

        if self._training_start_time is not None:
            self._training_latency += (time.time() - self._training_start_time)
            self._training_start_time = time.time()
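Taken together: eval() swaps in the fused inference forwards and generate() produces rollouts, while train()/step() restore the training path and reset the fused QKV views. That is the alternation an RLHF-style loop drives. A hedged sketch of the intended call pattern (the dataloader and loss function are placeholders, not part of this commit):

# engine, *_ = deepspeed.initialize(model=actor, args=args, enable_hybrid_engine=True)
for prompts in prompt_loader:                  # placeholder dataloader
    engine.eval()                              # switch modules to fused inference forwards
    rollout = engine.generate(prompts)

    engine.train()                             # restore original training forwards
    loss = compute_rl_loss(engine, rollout)    # placeholder loss for the sketch
    engine.backward(loss)
    engine.step()                              # also resets fused QKV views after the update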
@ -30,6 +30,9 @@ from deepspeed.runtime.constants import PIPE_REPLICATED
from numpy import prod
from deepspeed.accelerator import get_accelerator

from deepspeed.module_inject.policy import transpose
from torch.nn import functional as F

torch_memory_reserved = get_accelerator().memory_reserved
torch_max_memory_reserved = get_accelerator().max_memory_reserved
@ -300,6 +303,7 @@ def get_global_norm(norm_list):
    total_norm = 0.0
    for norm in norm_list:
        total_norm += norm**2.0
    # logger.info(f'norm_list = {norm_list} global = {sqrt(total_norm)}')
    return sqrt(total_norm)
@ -944,3 +948,28 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align
            shard_list.append(curr_shard)

        dist.all_gather(shard_list, shard_list[partition_id], dp_process_group[group_id])


class TLinear(torch.nn.Linear):

    def __init__(self, orig_layer, name=""):
        self.name = name
        super().__init__(orig_layer.weight.shape[1], orig_layer.weight.shape[0], bias=(orig_layer.bias is not None))
        self.weight.data = transpose(orig_layer.weight.data)
        self.bias = orig_layer.bias
        self._fwd_func = self._fwd_bias_add if self.bias is not None else self._fwd

    def _fwd(self, input):
        return F.linear(input, self.weight)

    def _fwd_bias_add(self, input):
        return F.linear(input, self.weight, bias=self.bias)

    def forward(self, input):
        return self._fwd_func(input)


def get_inactive_params(param_list):
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    return [param for param in param_list if (hasattr(param, 'ds_id') and \
            param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]
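TLinear stores the weight transposed so that training-side matmuls line up with the transposed_mode GEMM layout used by the inference containers. The underlying identity is simply that feeding F.linear a transposed weight flips which side of the GEMM the weight contracts on; a standalone check with assumed shapes:

import torch
import torch.nn.functional as F

b, d_in, d_out = 3, 5, 7
x = torch.randn(b, d_out)       # fed in the "transposed" orientation
W = torch.randn(d_out, d_in)    # an nn.Linear-style (out, in) weight

# F.linear(x, A) computes x @ A.t(), so storing W transposed yields x @ W
assert torch.allclose(F.linear(x, W.t()), x @ W)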
@ -34,7 +34,8 @@ ZeRO optimization should be enabled as:
    "offload_param": {...},
    "offload_optimizer": {...},
    "ignore_unused_parameters": [true|false],
    "round_robin_gradients": [true|false]
    "round_robin_gradients": [true|false],
    "memory_efficient_linear": [true|false]
    }
}
"""
@ -248,6 +249,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
    between optimizer steps) or GPU count (increased parallelism).
    """

    memory_efficient_linear: bool = True
    """
    Use memory efficient linear implementation, for Stage 3.
    """

    # Validators
    @validator("overlap_comm")
    def overlap_comm_valid(cls, field_value, values):
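For context, the new flag lands in the zero_optimization block of the DeepSpeed config; a minimal example (illustrative values) that opts out of the memory-efficient linear path under stage 3:

{
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 3,
        "memory_efficient_linear": false
    }
}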
@ -49,7 +49,6 @@ class LinearFunctionForZeroStage3(torch.autograd.Function):
    @autocast_custom_fwd
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        #print("In ZeRO Linear Function")

        weight_id = id(weight)
        bias_id = id(bias)
@ -461,7 +461,8 @@ class DeepSpeedZeRoOffload(object):

    @torch.no_grad()
    def pre_sub_module_backward_function(self, sub_module):
        param_coordinator = self.get_param_coordinator(training=sub_module.training)
        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
        param_coordinator = self.get_param_coordinator(training=True)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
@ -469,11 +470,12 @@ class DeepSpeedZeRoOffload(object):

    @torch.no_grad()
    def post_sub_module_backward_function(self, sub_module):
        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
            force=False)

        self.get_param_coordinator(training=sub_module.training).release_sub_module(sub_module)
        self.get_param_coordinator(training=True).release_sub_module(sub_module)

        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
@ -11,7 +11,6 @@ from enum import Enum
import functools
import itertools
from typing import List

import torch
from torch import Tensor
from deepspeed import comm as dist
@ -688,9 +687,10 @@ class Init(InsertPostInitMethodToModuleSubClasses):
            config_dict_or_path = config
            logger.warning(
                f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.')

        _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path,
                                                              mpu) if config_dict_or_path is not None else None
        if _ds_config is not None:
            mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear
        super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype)
        if not dist.is_initialized():
            init_distributed()
@ -138,10 +138,20 @@ class PartitionedParameterCoordinator:
    def trace_prologue(self, sub_module: Module) -> None:
        if self.is_complete_trace():
            # sub_module must match expectation else invalidate trace cache
            if len(self.__submodule_order) <= self.__step_id:
                print_rank_0(
                    f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
                    f"cache has only {len(self.__submodule_order)} modules",
                    force=True)
                self._invalidate_trace()
                return

            if sub_module != self.__submodule_order[self.__step_id]:
                expected_module_id = self.__submodule_order[self.__step_id].id
                debug_rank0(f"Invalidate trace cache @ step {self.__step_id}: "
                            f"expected module {expected_module_id}, but got module {sub_module.id}")
                print_rank_0(
                    f"Invalidate trace cache @ step {self.__step_id}: "
                    f"expected module {expected_module_id}, but got module {sub_module.id}",
                    force=True)
                self._invalidate_trace()

    def record_module(self, sub_module: Module) -> None:
@ -187,7 +197,9 @@ class PartitionedParameterCoordinator:
            self.__submodule_order = tuple(self.__submodule_order)  # freeze
            self.__param_order = tuple(self.__param_order)  # freeze
            self.__trace_mode = ZeRoTraceMode.COMPLETE
            print_rank_0(f"completed record trace: {[m.id for m in self.__submodule_order]}", force=False)
            print_rank_0(
                f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
                force=False)
        else:
            # Enable trace recording for next forward/backward pass
            self.__trace_mode = ZeRoTraceMode.RECORD
@ -33,7 +33,7 @@ pg_correctness_test = False
def print_rank_0(message, debug=False, force=False):
    rank = dist.get_rank()
    if rank == 0 and (debug or force):
        print(message)
        logger.info(message)
    # other variations
    # - print for all ranks w/o interleaving
    # printflock(f"[{rank}] {message}")
@ -1015,6 +1015,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
            self.__reduce_and_partition_ipg_grads()

        param_id = self.get_param_id(param)

        assert self.params_already_reduced[param_id] == False, \
            f"The parameter {param_id} has already been reduced. \
            Gradient computed twice for this partition. \
tests/hybrid_engine/hybrid_engine_config.json (new file, 19 lines)
@ -0,0 +1,19 @@
{
    "train_batch_size" : 32,
    "train_micro_batch_size_per_gpu": 2,
    "steps_per_print": 10,
    "zero_optimization": {
        "stage": 0,
        "offload_param": {
            "device": "cpu"
        },
        "stage3_param_persistence_threshold": 0
    },
    "fp16":{
        "enabled": true,
        "loss_scale_window": 100
    },
    "gradient_clipping": 1.0,
    "prescale_gradients": false,
    "wall_clock_breakdown" : false
}
tests/hybrid_engine/hybrid_engine_test.py (new file, 30 lines)
@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from transformers import AutoModelForCausalLM
import deepspeed
import argparse
from deepspeed.accelerator import get_accelerator

deepspeed.runtime.utils.see_memory_usage('pre test', force=True)

model = AutoModelForCausalLM.from_pretrained('facebook/opt-350M').half().to(get_accelerator().device_name())
parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()

deepspeed.runtime.utils.see_memory_usage('post test', force=True)

m, _, _, _ = deepspeed.initialize(model=model, args=args, enable_hybrid_engine=True)

m.eval()
input = torch.ones(1, 16, device='cuda', dtype=torch.long)
out = m(input)

m.train()
out = m(input)
print(out['logits'], out['logits'].norm())
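The test reads its DeepSpeed config through the standard --deepspeed_config argument added by add_config_arguments, so presumably it is launched with the deepspeed runner against the JSON config above, along the lines of:

deepspeed tests/hybrid_engine/hybrid_engine_test.py --deepspeed_config tests/hybrid_engine/hybrid_engine_config.json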
@ -1 +1 @@
0.8.3
0.9.0