'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import json
import math
import torch
from torch.autograd import Function
#from ...inference.engine import inference_cuda_module, specialized_mode
# Cuda modules will be imported if needed
inference_cuda_module = None
specialized_mode = None
import torch.nn as nn
from .ds_attention import DeepSpeedSelfAttention
from .config import DeepSpeedInferenceConfig
from ....moe.sharded_moe import TopKGate
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.builder_names import InferenceBuilder


class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
    """Initialize the DeepSpeed Transformer Config.

        Arguments:
            hidden_size: The hidden size of the transformer layer

            intermediate_size: The intermediate size of the feed-forward part of the transformer layer

            heads: The number of heads in the self-attention of the transformer layer

            num_hidden_layers: The number of transformer layers

            layer_norm_eps: The epsilon value for the layer norm

            local_rank: Optional: The rank of the GPU running the transformer kernel. It is not required
                if the model has already set the current device; otherwise it needs to be set so that
                the transformer kernel can work on the right device

            mp_size (optional): This argument is mainly used to create the parameters on the kernel side
                using the model-parallel architecture. If the client model already takes care of this,
                there is no need to pass this argument.

            fp16: Enable half-precision computation

            pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture

            stochastic_mode: Enable for higher performance. Note that this flag introduces some
                non-determinism and can produce different results on different runs. However, we have
                seen that enabling it does not affect pretraining tasks such as BERT, which can still
                reach a high accuracy level. For downstream tasks such as fine-tuning, we recommend
                turning it off in order to be able to reproduce the same result through regular
                kernel execution.

            scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.

            return_tuple: If True, returns the transformer output as a tuple, otherwise returns it as a tensor
    """
    def __init__(self,
                 hidden_size=-1,
                 intermediate_size=-1,
                 heads=-1,
                 num_hidden_layers=-1,
                 layer_norm_eps=1e-12,
                 local_rank=-1,
                 mp_size=1,
                 fp16=False,
                 q_int8=False,
                 pre_layer_norm=True,
                 stochastic_mode=False,
                 scale_attention=True,
                 triangular_masking=True,
                 local_attention=False,
                 window_size=256,
                 return_tuple=True,
                 moe_experts=1,
                 global_experts=1,
                 k=1,
                 capacity_factor=1.,
                 eval_capacity_factor=1.,
                 min_capacity=1,
                 noisy_gate_policy=None,
                 drop_tokens=True,
                 use_rts=False,
                 mlp_type='standard',
                 scale_attn_by_inverse_layer_idx=False):
        super(DeepSpeedMoEInferenceConfig, self).__init__(
            hidden_size,
            (intermediate_size if intermediate_size > 0 else 4 * hidden_size),
            heads,
            num_hidden_layers,
            layer_norm_eps,
            local_rank,
            mp_size,
            fp16,
            q_int8,
            pre_layer_norm,
            stochastic_mode,
            scale_attention,
            triangular_masking,
            local_attention,
            window_size,
            return_tuple)
        self.moe_experts = moe_experts
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts
        self.global_experts = global_experts
        self.mlp_type = mlp_type
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
    @classmethod
    def from_dict(cls, json_object):
        # Build a MoE config (not the base DeepSpeedInferenceConfig) so the
        # MoE-specific defaults are preserved before overriding from the dict.
        config = cls()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))
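

# Illustrative sketch (not part of the original file): constructing a config for a
# hypothetical fp16 model that hosts 2 of 16 experts on the local rank and uses
# top-1 gating. The argument names come from the constructor above; the values are
# made up for illustration only.
#
#   moe_config = DeepSpeedMoEInferenceConfig(hidden_size=1024,
#                                            heads=16,
#                                            num_hidden_layers=24,
#                                            fp16=True,
#                                            moe_experts=2,      # experts hosted on this rank
#                                            global_experts=16,  # experts across all ranks
#                                            k=1)                # top-1 gating
#
# The same settings could also be read from disk with
# DeepSpeedMoEInferenceConfig.from_json_file("moe_config.json") (hypothetical path).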


class DeepSpeedMLPFunction(Function):
    @staticmethod
    def forward(ctx,
                input,
                inter_w,
                inter_b,
                config,
                output_b,
                output_w,
                q_scales,
                q_groups,
                merge_count,
                mp_group,
                async_op):
        if config.q_int8:
            intermediate = inference_cuda_module.fused_gemm_gelu_int8(
                input,
                inter_w,
                inter_b,
                config.epsilon,
                q_scales[2],
                (q_groups * (2**merge_count)),
                config.pre_layer_norm)
            output = inference_cuda_module.vector_matmul_int8(intermediate,
                                                              output_w,
                                                              q_scales[3],
                                                              q_groups,
                                                              merge_count)
        else:
            mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
                inference_cuda_module.fused_gemm_gelu_fp32

            output = mlp_gemm_func(input,
                                   inter_w,
                                   inter_b,
                                   output_w,
                                   config.epsilon,
                                   config.pre_layer_norm,
                                   async_op)
        if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
            dist.all_reduce(output, group=mp_group, async_op=async_op)

        return output + output_b

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError('You are running with DeepSpeed Inference mode. '
                           'Please switch to Training mode for running backward!')
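

# Rough, unfused reference for what the fused kernels above compute (a sketch only;
# the exact semantics, including how `epsilon` and `pre_layer_norm` drive any fused
# layer norm, live inside the CUDA ops loaded into inference_cuda_module):
#
#   intermediate = torch.nn.functional.gelu(input @ inter_w + inter_b)
#   output = intermediate @ output_w
#   # partial results are all-reduced across mp_group, then output_b is added
#
# backward() raises on purpose: these kernels are inference-only.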


class DeepSpeedMoEMLP(nn.Module):
    def __init__(self,
                 config,
                 q_scales=None,
                 q_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False,
                 mp_group=None):
        super(DeepSpeedMoEMLP, self).__init__()

        self.config = config
        self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
        self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
        interm_size = self.config.intermediate_size // (
            1 if mp_group is None else dist.get_world_size(group=mp_group))
        self.inter_w = nn.Parameter(torch.Tensor(self.config.hidden_size, interm_size))
        self.inter_b = nn.Parameter(torch.Tensor(interm_size))
        self.output_w = nn.Parameter(torch.Tensor(interm_size, self.config.hidden_size))
        self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size))

        # used for quantization
        self.q_scales = q_scales
        self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups
        self.merge_count = int(math.log2(merge_count))
        self.mp_group = mp_group

    def forward(self, input, async_op=False):
        return DeepSpeedMLPFunction.apply(input,
                                          self.inter_w,
                                          self.inter_b,
                                          self.config,
                                          self.output_b,
                                          self.output_w,
                                          self.q_scales,
                                          self.q_groups,
                                          self.merge_count,
                                          self.mp_group,
                                          async_op)
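

# Note on DeepSpeedMoEMLP above: when an mp_group is supplied, the intermediate
# dimension is sharded across that group (interm_size = intermediate_size // world_size),
# so each rank holds a [hidden, interm/ws] and an [interm/ws, hidden] slice, and
# DeepSpeedMLPFunction all-reduces the partial outputs before the output bias is added.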


class DeepSpeedMoEInference(nn.Module):
    """Initialize the DeepSpeed MoE Transformer Layer.

        Arguments:
            layer_id: The layer index starting from 0, e.g. if the model has 24 transformer layers,
                layer_id will be 0,1,2...23 when each layer object is instantiated

            config: An object of DeepSpeedMoEInferenceConfig

            mp_group: Model parallelism group initialized on the modeling side.

            quantize_scales: This argument groups all the layers' scales used for quantization

            quantize_groups: Number of groups used for quantizing the model

            merge_count: The number of model-parallel checkpoints merged before running inference.
                We use this argument to control the quantization scale for the model parameters when
                a quantize-grouping larger than 1 is used.

            mlp_extra_grouping: This flag enables a 2x higher number of groups for the MLP part of a
                Transformer layer. We use this feature for quantization to reduce the convergence
                impact on specific downstream tasks.
    """
    layer_id = 0

    def __init__(self,
                 config,
                 mp_group=None,
                 ep_group=None,
                 expert_mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False,
                 qkv_merging=False):
        super(DeepSpeedMoEInference, self).__init__()

        self.config = config
        self.config.layer_id = DeepSpeedMoEInference.layer_id
        global inference_cuda_module
        global specialized_mode
        if inference_cuda_module is None:
            # Load the inference kernels lazily, only once for all layers.
            specialized_mode = False
            # InferenceSpecializedBuilder is not among the DeepSpeed-provided builders yet,
            # so we look it up by its builder name string.
            builder = get_accelerator().create_op_builder("InferenceSpecializedBuilder")
            if builder is not None and builder.is_compatible():
                inference_cuda_module = builder.load()
                specialized_mode = True
            else:
                inference_cuda_module = get_accelerator().create_op_builder(
                    InferenceBuilder).load()
        self.config.specialized_mode = specialized_mode

        DeepSpeedMoEInference.layer_id += 1
        self.attention = DeepSpeedSelfAttention(self.config,
                                                mp_group,
                                                quantize_scales,
                                                quantize_groups,
                                                merge_count,
                                                qkv_merging)
        self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
        self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))

        self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
        self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))

        if config.mlp_type == 'residual':
            self.res_mlp = DeepSpeedMoEMLP(config,
                                           quantize_scales,
                                           quantize_groups,
                                           merge_count,
                                           mlp_extra_grouping,
                                           mp_group)
            self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2))
            self.coef_func = inference_cuda_module.softmax_fp16 if self.config.fp16 or self.config.q_int8 else \
                inference_cuda_module.softmax_fp32
            self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
                inference_cuda_module.vector_matmul_fp32

        config.mp_size = 1
        self.mlp = nn.ModuleList(
            DeepSpeedMoEMLP(config,
                            quantize_scales,
                            quantize_groups,
                            merge_count,
                            mlp_extra_grouping,
                            expert_mp_group) for i in range(self.config.moe_experts))

        self.moe_gate = TopKGate(self.config.hidden_size,
                                 self.config.global_experts,
                                 self.config.k,
                                 self.config.capacity_factor,
                                 self.config.eval_capacity_factor,
                                 self.config.min_capacity,
                                 self.config.noisy_gate_policy,
                                 self.config.drop_tokens,
                                 self.config.use_rts)

        self.ep_group = ep_group
        self.mp_group = mp_group
        self.expert_mp_group = expert_mp_group

        print("DeepSpeed MoE Transformer Inference config is ", self.config.__dict__)

        self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \
            inference_cuda_module.bias_residual_fp32
        self.ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \
            inference_cuda_module.layer_norm_fp32
        self.einsum_sec_sm_ecm = inference_cuda_module.einsum_sec_sm_ecm_fp16 if self.config.fp16 or self.config.q_int8 else \
            inference_cuda_module.einsum_sec_sm_ecm_fp32

    def res_coef_func(self, inp, async_op):
        inp = self.vector_matmul_func(inp, self.res_coef, async_op)
        return self.coef_func(inp, torch.empty(1), False, False, False, 256, async_op)

    def moe_gate_einsum(self, attention_output):
        _, combined_weights, dispatch_mask, _ = self.moe_gate(
            attention_output.view(-1, self.config.hidden_size),
            None,
        )
        dispatched_attention = self.einsum_sec_sm_ecm(
            dispatch_mask.type_as(attention_output),
            attention_output.view(-1, self.config.hidden_size))
        return dispatched_attention, combined_weights
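
    # Shape sketch for moe_gate_einsum above (following the TopKGate/GShard convention;
    # an illustration, not asserted by this file): with s = tokens, e = global experts,
    # c = expert capacity and m = hidden_size,
    #
    #   dispatch_mask                : [s, e, c]  routing of each token to an (expert, slot)
    #   attention_output.view(-1, m) : [s, m]
    #   einsum('sec,sm->ecm')        : [e, c, m]  tokens grouped per expert, padded to capacity
    #
    # combined_weights holds the matching gate probabilities and is applied later in
    # scale_expert_output to weight the expert outputs back per token.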

    def expert_exec(self, dispatched_input):
        dispatched_input = dispatched_input.reshape(
            self.config.global_experts // self.config.moe_experts,
            self.config.moe_experts,
            -1,
            self.config.hidden_size)

        chunks = dispatched_input.chunk(self.config.moe_experts, dim=1)
        expert_outputs = torch.empty(
            (self.config.moe_experts, chunks[0].shape[0]) + chunks[0].shape[2:],
            dtype=dispatched_input.dtype,
            device=dispatched_input.device)
        for chunk, expert in zip(chunks, range(len(self.mlp))):
            expert_outputs[expert] = self.mlp[expert](chunk.view(
                -1,
                dispatched_input.shape[-2],
                dispatched_input.shape[-1]))
        return expert_outputs
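
    # Note on expert_exec above: after the dispatch all-to-all, tokens arrive as
    # [global_experts // moe_experts, moe_experts, tokens-per-expert, hidden_size],
    # i.e. one block per expert-parallel peer for each locally hosted expert. Every
    # local expert (self.mlp[i]) runs on its own chunk, and the results are stacked so
    # the reverse all-to-all can return them to the ranks that own the tokens.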

    def _alltoall(self, dispatched_attention):
        if dist.get_world_size(group=self.ep_group) > 1:
            dispatched_input = torch.empty_like(dispatched_attention)
            dist.all_to_all_single(dispatched_input,
                                   dispatched_attention,
                                   group=self.ep_group)
            return dispatched_input
        else:
            return dispatched_attention
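
    # _alltoall above is a no-op on a single expert-parallel rank; otherwise it
    # exchanges the per-expert blocks across self.ep_group so each rank receives the
    # tokens routed to the experts it hosts. The same call is reused afterwards to
    # send expert outputs back to the ranks that own the tokens.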

    def scale_expert_output(self, attention_output, expert_output, combined_weights):
        combined_output = torch.matmul(
            combined_weights.type_as(attention_output).reshape(
                combined_weights.shape[0], -1),
            expert_output.reshape(-1, expert_output.shape[-1]))
        return combined_output.reshape(attention_output.shape)
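
    # scale_expert_output above is the inverse of the dispatch einsum: combined_weights
    # reshaped to [s, e*c] is multiplied with the flattened expert outputs [e*c, m], so
    # every token receives the gate-weighted sum of the expert outputs it was routed
    # to, reshaped back to the attention output's original shape.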

    def forward(self,
                input,
                input_mask=None,
                attention_mask=None,
                head_mask=None,
                layer_past=None,
                get_key_value=False,
                get_present=False,
                encoder_output=None,
                enc_dec_attn_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                use_cache=False,
                output_attentions=False):
        get_present = (get_present or get_key_value or use_cache)
        input_mask = input_mask if attention_mask is None else attention_mask
        input_type = input.dtype

        if (self.config.fp16 or self.config.q_int8) \
                and input.dtype == torch.float:
            input = input.half()

        with torch.no_grad():
            attention_output = self.attention(input,
                                              input_mask,
                                              head_mask,
                                              layer_past,
                                              get_present,
                                              encoder_hidden_states,
                                              encoder_attention_mask,
                                              output_attentions,
                                              self.norm_w,
                                              self.norm_b)

            if get_present:
                attention_output, p_key, p_value = attention_output[0:3]
                presents = (p_key, p_value)
            elif output_attentions:
                attention_output, _, _, context_output = attention_output[0:4]
            else:
                attention_output = attention_output[0]

            residual_add = attention_output + self.attention.attn_ob
            attention_output = self.ds_layernorm(residual_add,
                                                 self.attn_nw,
                                                 self.attn_nb,
                                                 self.config.epsilon)

            if self.config.mlp_type == 'residual':
                res_mlp_out = self.res_mlp(attention_output, async_op=True)
                res_coef_out = self.res_coef_func(attention_output, async_op=True)

            if self.expert_mp_group is not None:
                tensor_list = [
                    torch.empty_like(attention_output)
                    for _ in range(dist.get_world_size(group=self.expert_mp_group))
                ]
                tensor_list[dist.get_rank(group=self.expert_mp_group)] = attention_output
                dist.all_gather(tensor_list,
                                attention_output,
                                group=self.expert_mp_group)
                attention_output = torch.cat(tensor_list).contiguous()

            ############## MoE Gating + Experts ###############
            dispatched_attention, combined_weights = self.moe_gate_einsum(attention_output)
            dispatched_input = self._alltoall(dispatched_attention)
            expert_outputs = self.expert_exec(dispatched_input)
            expert_output = self._alltoall(expert_outputs)
            output = self.scale_expert_output(attention_output,
                                              expert_output,
                                              combined_weights)
            ###################################################

            if self.expert_mp_group is not None:
                output = output.split(output.shape[0] //
                                      dist.get_world_size(group=self.expert_mp_group),
                                      dim=0)[dist.get_rank(group=self.expert_mp_group)]

            if self.config.mlp_type == 'residual':
                inference_cuda_module.moe_res_matmul(res_mlp_out, res_coef_out, output)

            output = self.bias_residual_func(output, residual_add, torch.empty(1))

            if not self.config.pre_layer_norm:
                output = self.ds_layernorm(output,
                                           self.norm_w,
                                           self.norm_b,
                                           self.config.epsilon)

            if input_type != output.dtype:
                output = output.to(input_type)

        if get_present:
            output = (output, presents)

        if self.config.return_tuple:
            return output if type(output) is tuple else (output, )
        else:
            return output
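

# Illustrative end-to-end sketch (not part of the original file): in practice this
# layer is created and wired up by DeepSpeed's inference engine (deepspeed.init_inference)
# during model injection. Conceptually, though, a single layer is driven like this,
# assuming the inference kernels can be built, weights are loaded from a checkpoint,
# and an expert-parallel process group `ep_group` already exists:
#
#   layer = DeepSpeedMoEInference(moe_config,   # e.g. the config sketched earlier in this file
#                                 ep_group=ep_group)
#   hidden_states = torch.rand(1, 128, moe_config.hidden_size, dtype=torch.half,
#                              device=get_accelerator().current_device_name())
#   output = layer(hidden_states)   # a tuple when moe_config.return_tuple is True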