DeepSpeed/deepspeed/ops/transformer/inference/ds_attention.py
Commit 9548d48f48 by Guokai Ma: Abstract accelerator (step 2) (#2560)
* Abstract accelerator (step 2)

* more flex op_builder path for both installation and runtime

* add SpatialInferenceBuilder into cuda_accelerator.py

* use reflection to make cuda_accelerator adapt to CUDA op builder change automatically (a sketch of this pattern follows the commit details below)

* clean up deepspeed/__init__.py

* add comments in cuda_accelerator for no torch path

* Update deepspeed/env_report.py

Change env_report.py according to the suggestion

Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>

* reduce the range of try...except for better code clarity

* Add porting for deepspeed/ops/random_ltd/dropping_utils.py

* move accelerator to top directory and create symlink under deepspeed

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
2023-01-06 23:40:58 -05:00
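
The reflection note above describes discovering CUDA op builders at runtime instead of listing them by hand. The sketch below only illustrates that pattern; the module path deepspeed.ops.op_builder and the "*Builder" naming convention are assumptions made for the example, not the code added by this PR.

# Hypothetical illustration; module path and naming convention are assumptions.
import importlib
import inspect


def discover_op_builders(module_name="deepspeed.ops.op_builder"):
    """Collect every class in `module_name` whose name ends with 'Builder'."""
    module = importlib.import_module(module_name)
    return {
        name: cls
        for name, cls in inspect.getmembers(module, inspect.isclass)
        if name.endswith("Builder")
    }


# e.g. discover_op_builders().get("SpatialInferenceBuilder") resolves a builder by name
# without hard-coding it in the accelerator.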

286 lines · 13 KiB · Python

'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import math
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from .op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp, SoftmaxOp
minus_inf = -10000.0


class DeepSpeedSelfAttention(nn.Module):
    num_layers = 0

    def __init__(self,
                 config,
                 mp_group=None,
                 q_scales=None,
                 q_groups=1,
                 merge_count=1,
                 qkv_merging=False):
        super(DeepSpeedSelfAttention, self).__init__()
        self.config = config
        data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
        data_type_fp = torch.half if config.fp16 else torch.float
        self.config.layer_id = DeepSpeedSelfAttention.num_layers
        DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'

        qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
        self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                  qkv_size_per_partition,
                                                  dtype=data_type,
                                                  device=device),
                                      requires_grad=False)
        self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition,
                                                  dtype=data_type_fp,
                                                  device=device),
                                      requires_grad=False)
        out_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                self.config.hidden_size,
                                                dtype=data_type,
                                                device=device),
                                    requires_grad=False)
        self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size,
                                                dtype=data_type_fp,
                                                device=device),
                                    requires_grad=False)

        self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads
        self.mp_group = mp_group

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups
        self.merge_count = int(math.log2(merge_count))

        self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
        if not config.use_mup:
            self.norm_factor = math.sqrt(self.norm_factor)
        self.qkv_merging = qkv_merging

        if self.config.scale_attn_by_inverse_layer_idx is True:
            self.norm_factor *= math.sqrt(self.config.layer_id + 1)
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191

        self.qkv_func = QKVGemmOp(config)
        self.score_context_func = SoftmaxContextOp(config)
        self.linear_func = LinearOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)
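
    # Tensor-parallel sharding: each of the mp_size ranks holds 3 * hidden_size / mp_size output
    # columns of the fused QKV weight and hidden_size / mp_size rows of the output projection;
    # forward() all-reduces the attention output across mp_group when config.mlp_after_attn is set.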

    def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
        if isinstance(qkv_out, list):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None
        if no_masking:
            input_mask = torch.empty(1)

        attn_key_value = self.score_context_func(
            query_key_value=qkv_out,
            attn_mask=((1 - input_mask).to(qkv_out.dtype) *
                       minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            heads=self.num_attention_heads_per_partition,
            norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
            no_masking=no_masking,
            layer_id=self.config.layer_id,
            num_layers=DeepSpeedSelfAttention.num_layers,
            alibi=alibi)

        context_layer, key_layer, value_layer = attn_key_value
        return context_layer, key_layer, value_layer

    def forward(self,
                input,
                input_mask,
                head_mask=None,
                layer_past=None,
                get_present=False,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                output_attentions=False,
                norm_w=None,
                norm_b=None,
                alibi=None):
        if not self.config.pre_layer_norm:
            qkv_out = self.linear_func(input=input,
                                       weight=self.attn_qkvw,
                                       bias=self.attn_qkvb,
                                       add_bias=self.attn_qkvb is not None,
                                       do_flash_attn=False,
                                       num_heads=self.num_attention_heads_per_partition,
                                       num_layers=DeepSpeedSelfAttention.num_layers)
        else:
            qkv_out = self.qkv_func(input=input,
                                    weight=self.attn_qkvw,
                                    bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
                                    gamma=norm_w,
                                    beta=norm_b,
                                    add_bias=(self.attn_qkvb is not None),
                                    num_layers=DeepSpeedSelfAttention.num_layers,
                                    num_heads=self.num_attention_heads_per_partition)

        context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
                                                                       input_mask=input_mask,
                                                                       layer_past=layer_past,
                                                                       alibi=alibi)
        output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)
        inp_norm = qkv_out[-1]

        if self.config.mlp_after_attn and self.mp_group is not None and \
                dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(output, group=self.mp_group)

        return (output, key_layer, value_layer, context_layer, inp_norm)


class BloomSelfAttention(DeepSpeedSelfAttention):
    def __init__(self, *args, **kwargs):
        super(BloomSelfAttention, self).__init__(*args, **kwargs)
        self.softmax_func = SoftmaxOp(self.config)

    ########### This part is taken/modified from the HF modeling_bloom.py ################
    # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py

    def _transpose_for_context(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_layer_shape = x.size()[:-2] + (self.hidden_size_per_partition, )
        return x.view(*new_x_layer_shape).contiguous()

    def _split_tensor_along_last_dim(self,
                                     tensor,
                                     num_partitions,
                                     contiguous_split_chunks=True):
        """Split a tensor along its last dimension.

        Args:
            tensor ([`torch.Tensor`], *required*):
                input tensor to split
            num_partitions ([`int`], *required*):
                number of partitions to split the tensor into
            contiguous_split_chunks ([`bool`], *optional*, default=`True`):
                If True, make each chunk contiguous in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        numerator, denominator = tensor.size()[last_dim], num_partitions
        if not (numerator % denominator == 0):
            raise ValueError(f"{numerator} is not divisible by {denominator}")
        last_dim_size = numerator // denominator
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list
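
    # Shape walk-through: with mixed_x_layer of shape [batch, seq, heads, 3 * head_dim],
    # _split_tensor_along_last_dim(mixed_x_layer, 3) returns three chunks of shape
    # [batch, seq, heads, head_dim] -- the query, key and value slices consumed below.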

    def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
        if isinstance(qkv_out, list):
            qkv_out = qkv_out[0]

        no_masking = input_mask is None
        if no_masking:
            input_mask = torch.empty(1)

        mixed_x_layer = qkv_out
        alibi = alibi.to(get_accelerator().current_device_name())
        head_dim = self.hidden_size_per_partition // self.num_attention_heads_per_partition
        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                                        3 * head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
        query_layer, key_layer, value_layer = self._split_tensor_along_last_dim(mixed_x_layer, 3)

        # output_size: [batch_size, num_heads, q_length, k_length]
        output_size = (query_layer.size(0),
                       query_layer.size(2),
                       query_layer.size(1),
                       key_layer.size(1))

        # [batch_size, q_length, num_heads, head_dim] -> [batch_size * num_heads, q_length, head_dim]
        query_layer = query_layer.transpose(1, 2).reshape(output_size[0] * output_size[1],
                                                          output_size[2],
                                                          -1)
        # [batch_size, k_length, num_heads, head_dim] -> [batch_size * num_heads, head_dim, k_length]
        key_layer = key_layer.transpose(1, 2).reshape(output_size[0] * output_size[1],
                                                      output_size[3],
                                                      -1).transpose(-1, -2)
        # [batch_size, k_length, num_heads, head_dim] -> [batch_size * num_heads, k_length, head_dim]
        value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1],
                                                          output_size[3],
                                                          -1)
        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate the cached keys/values along their sequence dimension
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)

        presents = (key_layer, value_layer)

        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.matmul(query_layer, key_layer)
        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(output_size[0], output_size[1], output_size[2], -1)

        offset = (dist.get_rank() * self.num_attention_heads_per_partition
                  if dist.is_initialized() else 0)
        attention_probs = self.softmax_func(
            attn_scores=attention_scores,
            attn_mask=((1 - input_mask).half() *
                       minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            alibi=alibi,
            triangular=(self.config.triangular_masking
                        and (attention_scores.shape[-2] > 1)),
            recompute=False,
            local_attention=False,
            window_size=1,
            async_op=False,
            layer_scale=1 / (self.norm_factor * self.norm_factor),
            head_offset=offset)

        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(
            context_layer.size(0) // self.num_attention_heads_per_partition,
            self.num_attention_heads_per_partition,
            context_layer.size(1),
            context_layer.shape[-1])

        context_layer = self._transpose_for_context(context_layer)
        key_layer = presents[0]
        value_layer = presents[1]
        return context_layer, key_layer, value_layer

    ###################### End of HF modeling_bloom addition ########################
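
For orientation, here is a minimal usage sketch of DeepSpeedSelfAttention. It is an assumption-laden example, not code from the repository: it presumes the inference kernels can be built for the current accelerator, that DeepSpeedInferenceConfig is importable from the path shown, and it leaves the weights at their torch.empty initialization (a real model would load a checkpoint first).

# Usage sketch only; import path, config fields and runnability depend on the local build.
import torch
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention

config = DeepSpeedInferenceConfig(hidden_size=1024, heads=16, fp16=True, pre_layer_norm=True)
attn = DeepSpeedSelfAttention(config, mp_group=None)

device = attn.attn_qkvw.device
hidden_states = torch.randn(1, 8, 1024, dtype=torch.half, device=device)
# pre_layer_norm=True routes through qkv_func, which expects the attention LayerNorm parameters.
norm_w = torch.ones(1024, dtype=torch.half, device=device)
norm_b = torch.zeros(1024, dtype=torch.half, device=device)

output, key_layer, value_layer, context_layer, inp_norm = attn(hidden_states,
                                                               input_mask=None,
                                                               norm_w=norm_w,
                                                               norm_b=norm_b)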