# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
from deepspeed.utils.logging import log_dist

from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
from deepspeed.accelerator import get_accelerator
import deepspeed

if deepspeed.HAS_TRITON and get_accelerator().is_triton_supported():
    from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
    from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention

class DeepSpeedTransformerInference(nn.Module):
    """Initialize the DeepSpeed Transformer Layer.

        Arguments:
            layer_id: The layer index starting from 0; e.g. if the model has 24 transformer layers,
                layer_id will be 0, 1, 2 ... 23 when each layer object is instantiated.
            config: An object of DeepSpeedInferenceConfig.
            mp_group: Model parallelism group initialized on the modeling side.
            quantize_scales: This argument groups all the layers' scales used for quantization.
            quantize_groups: Number of groups used for quantizing the model.
            merge_count: Shows the number of model-parallel checkpoints merged before running inference.
                We use this argument to control the quantization scale for the model parameters if a bigger
                quantize-grouping than 1 is used.
            mlp_extra_grouping: This flag is used to show a 2x higher number of groups used for the MLP part
                of a Transformer layer. We use this feature for quantization to reduce the convergence impact
                for specific downstream tasks.
    """
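    # Class-level state shared by every instance: `layer_id` counts layers as they are
    # constructed, and `workspace` holds the single WorkspaceOp shared across all layers.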
    layer_id = 0
    workspace = None

    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        super(DeepSpeedTransformerInference, self).__init__()

        self.config = config
        self.config.layer_id = DeepSpeedTransformerInference.layer_id
        DeepSpeedTransformerInference.layer_id += 1

        data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype

        if DeepSpeedTransformerInference.layer_id == 1:
            log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
            if deepspeed.HAS_TRITON and self.config.use_triton:
                log_dist("Injecting Triton kernels ...", [0])

        if self.config.bigscience_bloom:
            self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
            assert not self.config.use_triton
        else:
            if deepspeed.HAS_TRITON and self.config.use_triton:
                self.attention = TritonSelfAttention(self.config)
            else:
                self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
                                                        merge_count)

        if deepspeed.HAS_TRITON and self.config.use_triton:
            self.mlp = TritonMLP(self.config)
        else:
            self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
                                    mlp_extra_grouping)

        device = get_accelerator().current_device_name()  # if config.bigscience_bloom else 'cpu'
        if self.config.set_empty_params:
            self.norm_w = None
            self.norm_b = None
        else:
            self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
            self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
        self.layer_past = None
        self.layer_norm = LayerNormOp()
        if DeepSpeedTransformerInference.workspace is None:
            DeepSpeedTransformerInference.workspace = WorkspaceOp(self.config)
        self._should_allocate_workspace = True

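    # The shared workspace backs the inference KV-cache and scratch buffers for every layer.
    # It is sized lazily on the first forward call of layer 0 (see allocate_workspace below).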
    def allocate_workspace(self, size):
        # Allocate memory only on first layer forward
        if self.config.layer_id == 0 and self._should_allocate_workspace:
            DeepSpeedTransformerInference.workspace.allocate_workspace(
                self.config.hidden_size, self.config.heads, size[1], size[0], DeepSpeedTransformerInference.layer_id,
                self.config.mp_size, self.config.bigscience_bloom,
                dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
                self.config.min_out_tokens)
            self._should_allocate_workspace = False

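    # Clears any cached state tracked by the shared workspace (e.g. between unrelated
    # generation requests) so a new sequence starts from an empty cache.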
    @classmethod
    def reset_cache(cls):
        if cls.workspace is not None:
            cls.workspace.reset_cache()

    def forward(
            self,
            input=None,
            input_mask=None,
            attention_mask=None,
            attn_mask=None,
            head_mask=None,
            layer_past=None,
            get_key_value=False,
            get_present=False,
            encoder_output=None,
            enc_dec_attn_mask=None,
            x=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            use_cache=False,
            alibi=None,
            output_attentions=False,
            # TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
            # This needs to be redesigned later!
            layer_head_mask=None,
            past_key_value=None,
            **kwargs):
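        # Different callers (e.g. the various HuggingFace model classes) pass the hidden states
        # and attention mask under different argument names; normalize them to
        # `input` / `input_mask` before running the fused kernels.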
        if x is not None:
            input = x
        if "hidden_states" in kwargs:
            input = kwargs["hidden_states"]

        if layer_past is not None and past_key_value is not None:
            raise ValueError("Only one of `layer_past` or `past_key_value` can be present.")

        input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask

        self.allocate_workspace(input.size())

        get_present = (get_present or get_key_value or use_cache)
        input_mask = input_mask if attention_mask is None else attention_mask

        # We set the prev key/value to None when there is a prompt
        if input.shape[1] > 1:
            self.layer_past = None
        _layer_past = layer_past or past_key_value or self.layer_past
        head_mask = layer_head_mask if layer_head_mask is not None else head_mask

        attn_mask = None
        if isinstance(input, tuple):
            attn_mask = input[1]
            input = input[0]
        input_type = input.dtype

        if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \
                and input.dtype == torch.float:
            target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype
            input = input.to(target_dtype)

        with torch.no_grad():
            attention_output, key, value, context_output, inp_norm = \
                self.attention(input,
                               input_mask,
                               head_mask,
                               _layer_past,
                               get_present,
                               encoder_hidden_states,
                               encoder_attention_mask,
                               output_attentions,
                               self.norm_w,
                               self.norm_b,
                               alibi,
                               **kwargs)

            presents = (key, value)
            self.layer_past = presents if layer_past is None and past_key_value is None else None
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

            if not self.config.pre_layer_norm:
                output = self.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)

            output = output.to(input_type)
        if get_present:
            output = (output, presents)

        if self.config.return_single_tuple:
            return (output, )
        elif self.config.return_tuple:
            return output if type(output) is tuple else (output, attn_mask)
        else:
            return output
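
# Illustrative note (a sketch, not executed as part of this module): this layer is normally
# created by DeepSpeed's kernel-injection path rather than instantiated by hand. A typical
# entry point looks roughly like the following; exact kwargs depend on the DeepSpeed version.
#
#     import deepspeed, torch
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)
#     # engine.module now holds DeepSpeedTransformerInference blocks in place of the original
#     # attention/MLP layers, and generation proceeds through their forward() defined above.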