DeepSpeed/deepspeed/model_implementations/transformers/ds_transformer.py
Yuanyuan Chen 1c03d1b1bb Fix invalid f-strings (#7457)
Fix invalid f-strings detected by ruff.

---------

Signed-off-by: cyy <cyyever@outlook.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Michael Wyatt <michael.wyatt@snowflake.com>
2025-08-16 18:22:19 +00:00

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
from deepspeed.utils.logging import log_dist
from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
from deepspeed.accelerator import get_accelerator
import deepspeed
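
# Triton kernels are imported only when the Triton package is present and the current
# accelerator reports Triton support; otherwise the default op bindings above are used.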
if deepspeed.HAS_TRITON and get_accelerator().is_triton_supported():
    from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
    from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention


class DeepSpeedTransformerInference(nn.Module):
    """Initialize the DeepSpeed Transformer Layer.

    Arguments:
        layer_id: The layer index starting from 0; e.g. if the model has 24 transformer layers,
            layer_id will be 0, 1, 2 ... 23 as each layer object is instantiated.
        config: An object of DeepSpeedInferenceConfig.
        mp_group: Model parallelism group initialized on the modeling side.
        quantize_scales: This argument groups all the layers' scales used for quantization.
        quantize_groups: Number of groups used for quantizing the model.
        merge_count: The number of model-parallel checkpoints merged before running inference.
            We use this argument to control the quantization scale for the model parameters
            when a quantize-grouping larger than 1 is used.
        mlp_extra_grouping: This flag selects a 2x higher number of groups for the MLP part
            of a Transformer layer. We use this feature for quantization to reduce the
            convergence impact on specific downstream tasks.
    """
    layer_id = 0
    workspace = None

    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        super(DeepSpeedTransformerInference, self).__init__()

        self.config = config
        self.config.layer_id = DeepSpeedTransformerInference.layer_id
        DeepSpeedTransformerInference.layer_id += 1
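
        # For int8-quantized models the norm parameters (and activations) stay in fp16;
        # otherwise the configured dtype is used directly.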
        data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        if DeepSpeedTransformerInference.layer_id == 1:
            log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
            if deepspeed.HAS_TRITON and self.config.use_triton:
                log_dist("Injecting Triton kernels ...", [0])

        if self.config.bigscience_bloom:
            self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
            assert not self.config.use_triton
        else:
            if deepspeed.HAS_TRITON and self.config.use_triton:
                self.attention = TritonSelfAttention(self.config)
            else:
                self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
                                                        merge_count)

        if deepspeed.HAS_TRITON and self.config.use_triton:
            self.mlp = TritonMLP(self.config)
        else:
            self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
                                    mlp_extra_grouping)

        device = get_accelerator().current_device_name()  # if config.bigscience_bloom else 'cpu'
        if self.config.set_empty_params:
            self.norm_w = None
            self.norm_b = None
        else:
            self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
            self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                       requires_grad=False)
        self.layer_past = None
        self.layer_norm = LayerNormOp()
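
        # The workspace op is created once and shared through the class attribute;
        # only layer 0 actually allocates device memory (see allocate_workspace below).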
        if DeepSpeedTransformerInference.workspace is None:
            DeepSpeedTransformerInference.workspace = WorkspaceOp(self.config)
        self._should_allocate_workspace = True
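
    # `size` is the shape of the incoming hidden states (forward() passes input.size()),
    # so size[0] is the batch size and size[1] is the sequence length.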
    def allocate_workspace(self, size):
        # Allocate memory only on first layer forward
        if self.config.layer_id == 0 and self._should_allocate_workspace:
            DeepSpeedTransformerInference.workspace.allocate_workspace(
                self.config.hidden_size, self.config.heads, size[1], size[0], DeepSpeedTransformerInference.layer_id,
                self.config.mp_size, self.config.bigscience_bloom,
                dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
                self.config.min_out_tokens)
            self._should_allocate_workspace = False

    @classmethod
    def reset_cache(cls):
        if cls.workspace is not None:
            cls.workspace.reset_cache()

    def forward(
            self,
            input=None,
            input_mask=None,
            attention_mask=None,
            attn_mask=None,
            head_mask=None,
            layer_past=None,
            get_key_value=False,
            get_present=False,
            encoder_output=None,
            enc_dec_attn_mask=None,
            x=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            use_cache=False,
            alibi=None,
            output_attentions=False,
            # TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
            # This needs to be redesigned later!
            layer_head_mask=None,
            past_key_value=None,
            **kwargs):
        if x is not None:
            input = x
        if "hidden_states" in kwargs:
            input = kwargs["hidden_states"]

        if layer_past is not None and past_key_value is not None:
            raise ValueError("Only one of `layer_past` or `past_key_value` can be present.")
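
        # Select the attention mask with precedence attention_mask > attn_mask > input_mask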
        input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask
        self.allocate_workspace(input.size())
        get_present = (get_present or get_key_value or use_cache)

        # We set the prev key/value to None when there is a prompt
        if input.shape[1] > 1:
            self.layer_past = None
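
        # Cache source precedence: an explicitly passed layer_past wins over
        # past_key_value, which wins over the internally tracked self.layer_past.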
        _layer_past = layer_past or past_key_value or self.layer_past
        head_mask = layer_head_mask if layer_head_mask is not None else head_mask

        attn_mask = None
        if isinstance(input, tuple):
            attn_mask = input[1]
            input = input[0]
        input_type = input.dtype
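
        # The kernels run in the configured low-precision dtype, so fp32 inputs are
        # cast down; int8-quantized models keep activations in fp16.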
        if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \
                and input.dtype == torch.float:
            target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype
            input = input.to(target_dtype)

        with torch.no_grad():
            attention_output, key, value, context_outputtn_ctx, inp_norm = \
                self.attention(input,
                               input_mask,
                               head_mask,
                               _layer_past,
                               get_present,
                               encoder_hidden_states,
                               encoder_attention_mask,
                               output_attentions,
                               self.norm_w,
                               self.norm_b,
                               alibi,
                               **kwargs)
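
            # Keep the new key/value pair as internal state only when the caller is not
            # managing the cache itself via layer_past/past_key_value.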
            presents = (key, value)
            self.layer_past = presents if layer_past is None and past_key_value is None else None
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

            if not self.config.pre_layer_norm:
                output = self.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)

            output = output.to(input_type)
        if get_present:
            output = (output, presents)

        if self.config.return_single_tuple:
            return (output, )
        elif self.config.return_tuple:
            return output if type(output) is tuple else (output, attn_mask)
        else:
            return output
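

# A minimal usage sketch, assuming a supported accelerator with the DeepSpeed inference
# kernels built; the config values below are illustrative assumptions, not values taken
# from this file:
#
#     from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
#
#     config = DeepSpeedInferenceConfig(hidden_size=768, heads=12, dtype=torch.float16)
#     layer = DeepSpeedTransformerInference(config)
#     hidden_states = torch.randn(1, 128, 768,
#                                 dtype=torch.float16,
#                                 device=get_accelerator().current_device_name())
#     output = layer(hidden_states, use_cache=True)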