'''
Copyright 2022 The Microsoft DeepSpeed Team
'''

import math
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from .op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp, SoftmaxOp

minus_inf = -10000.0


class DeepSpeedSelfAttention(nn.Module):
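    """Fused self-attention block used by DeepSpeed's transformer inference path.

    Holds the (optionally tensor-parallel sharded) QKV and output-projection
    parameters and dispatches the heavy lifting to the fused op bindings
    (LinearOp / QKVGemmOp / SoftmaxContextOp / VectorMatMulOp).
    """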
    num_layers = 0

    def __init__(self,
                 config,
                 mp_group=None,
                 q_scales=None,
                 q_groups=1,
                 merge_count=1,
                 qkv_merging=False):
        super(DeepSpeedSelfAttention, self).__init__()
        self.config = config
        data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
        data_type_fp = torch.half if config.fp16 else torch.float
        self.config.layer_id = DeepSpeedSelfAttention.num_layers
        DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'
        qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
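        # Column-parallel fused QKV projection: each tensor-parallel rank holds
        # [hidden_size, 3 * hidden_size / mp_size]; the output projection below is
        # the matching row-parallel shard [hidden_size / mp_size, hidden_size].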
        self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                  qkv_size_per_partition,
                                                  dtype=data_type,
                                                  device=device),
                                      requires_grad=False)
        self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition,
                                                  dtype=data_type_fp,
                                                  device=device),
                                      requires_grad=False)
        out_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                self.config.hidden_size,
                                                dtype=data_type,
                                                device=device),
                                    requires_grad=False)

        self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size,
                                                dtype=data_type_fp,
                                                device=device),
                                    requires_grad=False)

        self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads

        self.mp_group = mp_group

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups
        self.merge_count = int(math.log2(merge_count))

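        # norm_factor is consumed squared in the BLOOM softmax below
        # (layer_scale = 1 / norm_factor**2), so the extra sqrt taken in the
        # non-muP case yields the usual 1/sqrt(head_dim) scaling of the scores,
        # while use_mup keeps a 1/head_dim scaling.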
        self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
        if not config.use_mup:
            self.norm_factor = math.sqrt(self.norm_factor)
        self.qkv_merging = qkv_merging

        if self.config.scale_attn_by_inverse_layer_idx is True:
            self.norm_factor *= math.sqrt(self.config.layer_id + 1)
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191

        self.qkv_func = QKVGemmOp(config)
        self.score_context_func = SoftmaxContextOp(config)
        self.linear_func = LinearOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)

    def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
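        """Run the fused softmax/context kernel on the fused QKV activations.

        An int64 0/1 `input_mask` (HuggingFace style) is converted into an
        additive mask, i.e. `(1 - mask) * minus_inf`, before being handed to the
        kernel; other mask dtypes are passed through unchanged.
        Returns `(context_layer, key_layer, value_layer)`.
        """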
        if isinstance(qkv_out, list):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None

        if no_masking:
            input_mask = torch.empty(1)

        attn_key_value = self.score_context_func(
            query_key_value=qkv_out,
            attn_mask=((1 - input_mask).to(qkv_out.dtype) *
                       minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            heads=self.num_attention_heads_per_partition,
            norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
            no_masking=no_masking,
            layer_id=self.config.layer_id,
            num_layers=DeepSpeedSelfAttention.num_layers,
            alibi=alibi)

        context_layer, key_layer, value_layer = attn_key_value
        return context_layer, key_layer, value_layer

    def forward(self,
                input,
                input_mask,
                head_mask=None,
                layer_past=None,
                get_present=False,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                output_attentions=False,
                norm_w=None,
                norm_b=None,
                alibi=None):
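        """Attention forward pass for inference.

        With `config.pre_layer_norm` the fused QKV GEMM (`qkv_func`) also applies
        layer norm using `norm_w` / `norm_b`; otherwise a plain fused linear
        projection is used. When `mlp_after_attn` is set and a model-parallel
        group is present, the output projection is all-reduced across that group.
        Returns `(output, key_layer, value_layer, context_layer, inp_norm)`,
        where `inp_norm` is the last element of `qkv_out`.
        """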
        if not self.config.pre_layer_norm:
            qkv_out = self.linear_func(input=input,
                                       weight=self.attn_qkvw,
                                       bias=self.attn_qkvb,
                                       add_bias=self.attn_qkvb is not None,
                                       do_flash_attn=False,
                                       num_heads=self.num_attention_heads_per_partition,
                                       num_layers=DeepSpeedSelfAttention.num_layers)
        else:
            qkv_out = self.qkv_func(
                input=input,
                weight=self.attn_qkvw,
                bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
                gamma=norm_w,
                beta=norm_b,
                add_bias=(self.attn_qkvb is not None),
                num_layers=DeepSpeedSelfAttention.num_layers,
                num_heads=self.num_attention_heads_per_partition)

        context_layer, key_layer, value_layer = self.compute_attention(
            qkv_out=qkv_out,
            input_mask=input_mask,
            layer_past=layer_past,
            alibi=alibi)

        output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)

        inp_norm = qkv_out[-1]

        if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(
                group=self.mp_group) > 1:
            dist.all_reduce(output, group=self.mp_group)

        return (output, key_layer, value_layer, context_layer, inp_norm)


class BloomSelfAttention(DeepSpeedSelfAttention):
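    """BLOOM-specific self-attention.

    Reuses the parameters and QKV path of `DeepSpeedSelfAttention` but overrides
    `compute_attention` with an ALiBi-aware implementation adapted from the
    HuggingFace `modeling_bloom.py` code, driving the fused `SoftmaxOp`.
    """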
    def __init__(self, *args, **kwargs):
        super(BloomSelfAttention, self).__init__(*args, **kwargs)
        self.softmax_func = SoftmaxOp(self.config)

    ########### This part is taken/modified from the HF modeling_bloom.py ################
    # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py

    def _transpose_for_context(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_layer_shape = x.size()[:-2] + \
                            (self.hidden_size_per_partition,)
        return x.view(*new_x_layer_shape).contiguous()

    def _split_tensor_along_last_dim(self,
                                     tensor,
                                     num_partitions,
                                     contiguous_split_chunks=True):
        """Split a tensor along its last dimension.

        Args:
            tensor ([`torch.Tensor`], *required*):
                input tensor to split
            num_partitions ([`int`], *required*):
                number of partitions to split the tensor into
            contiguous_split_chunks ([`bool`], *optional*, default=`True`):
                If True, make each chunk contiguous in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        numerator, denominator = tensor.size()[last_dim], num_partitions
        if not (numerator % denominator == 0):
            raise ValueError(f"{numerator} is not divisible by {denominator}")
        last_dim_size = numerator // denominator
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list

    def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
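        """BLOOM attention with ALiBi, adapted from HF `modeling_bloom`.

        Splits the fused QKV activations per head, appends any cached
        `layer_past` keys/values, computes raw attention scores with a batched
        matmul, applies the fused `SoftmaxOp` (ALiBi, masking and scaling), and
        returns `(context_layer, key_layer, value_layer)`.
        """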
        if isinstance(qkv_out, list):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None

        if no_masking:
            input_mask = torch.empty(1)

        mixed_x_layer = qkv_out
        alibi = alibi.to(get_accelerator().current_device_name())
        head_dim = self.hidden_size_per_partition // self.num_attention_heads_per_partition
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        query_layer, key_layer, value_layer = self._split_tensor_along_last_dim(mixed_x_layer, 3)

        # output_size: [batch_size, num_heads, q_length, k_length]
        output_size = (query_layer.size(0),
                       query_layer.size(2),
                       query_layer.size(1),
                       key_layer.size(1))
        # [batch_size, q_length, num_heads, head_dim] -> [batch_size * num_heads, q_length, head_dim]
        query_layer = query_layer.transpose(1, 2).reshape(output_size[0] * output_size[1],
                                                          output_size[2],
                                                          -1)
        # [batch_size, k_length, num_heads, head_dim] -> [batch_size * num_heads, head_dim, k_length]
        key_layer = key_layer.transpose(1, 2).reshape(output_size[0] * output_size[1],
                                                      output_size[3],
                                                      -1).transpose(-1, -2)
        # [batch_size, v_length, num_heads, head_dim] -> [batch_size * num_heads, v_length, head_dim]
        value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1],
                                                          output_size[3],
                                                          -1)
        if layer_past is not None:
            past_key, past_value = layer_past
            # Append the cached keys/values along their sequence-length dimensions
            # (last dim of key_layer, dim -2 of value_layer).
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)

        presents = (key_layer, value_layer)
        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.matmul(query_layer, key_layer)
        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(output_size[0],
                                              output_size[1],
                                              output_size[2],
                                              -1)

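        # Each tensor-parallel rank owns a contiguous slice of heads, so tell the
        # fused softmax where this rank's heads start in the full ALiBi slope table.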
        offset = (dist.get_rank() * self.num_attention_heads_per_partition
                  if dist.is_initialized() else 0)
        attention_probs = self.softmax_func(
            attn_scores=attention_scores,
            attn_mask=((1 - input_mask).half() *
                       minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            alibi=alibi,
            triangular=(self.config.triangular_masking
                        and (attention_scores.shape[-2] > 1)),
            recompute=False,
            local_attention=False,
            window_size=1,
            async_op=False,
            layer_scale=1 / (self.norm_factor * self.norm_factor),
            head_offset=offset)

        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(
            context_layer.size(0) // self.num_attention_heads_per_partition,
            self.num_attention_heads_per_partition,
            context_layer.size(1),
            context_layer.shape[-1])

        context_layer = self._transpose_for_context(context_layer)
        key_layer = presents[0]
        value_layer = presents[1]

        return context_layer, key_layer, value_layer

    ###################### End of HF modeling_bloom addition ########################
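

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only). These modules are normally constructed by
# DeepSpeed's inference engine; the `config` below is a stand-in for the real
# inference config and only names attributes this file actually reads
# (hidden_size, heads, mp_size, fp16, q_int8, pre_layer_norm, ...).
#
#   attn = DeepSpeedSelfAttention(config, mp_group=None)
#   # ...load real weights into attn.attn_qkvw / attn_qkvb / attn_ow / attn_ob...
#   output, key_layer, value_layer, context_layer, inp_norm = attn(
#       input=hidden_states,        # [batch, seq_len, hidden_size]
#       input_mask=attention_mask,  # int64 0/1 mask or None
#       norm_w=ln_weight,
#       norm_b=ln_bias)
# ----------------------------------------------------------------------------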