# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy


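# Inference container for Hugging Face CLIP encoder layers; create_module()
# reuses the generic DeepSpeedGPTInference transformer module.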
class DS_CLIPContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module


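# Injection policy describing where to find the Q/K/V, output projection, MLP
# and LayerNorm parameters on transformers' CLIPEncoderLayer.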
class HFCLIPLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=False):
        super().__init__(inference, pre_attn_norm=True, scale_attention=True)
        self.client_module = client_module
        self.cuda_graph_supported = True

        if HFCLIPLayerPolicy._orig_layer_class is None:
            try:
                import transformers
                HFCLIPLayerPolicy._orig_layer_class = transformers.models.clip.modeling_clip.CLIPEncoderLayer
            except (ImportError, AttributeError):
                # transformers is missing or too old to provide the CLIP model.
                HFCLIPLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.self_attn.q_proj.weight.shape[1], \
               self.client_module.self_attn.num_heads, \
               self.client_module.layer_norm1.eps

    def get_q_k_v(self):
        return None

    def attention(self):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias
        kw = self.client_module.self_attn.k_proj.weight
        kb = self.client_module.self_attn.k_proj.bias
        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        # Fuse the separate Q, K and V projections into a single QKV weight and bias.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)

        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
               self.client_module.self_attn.out_proj.bias

    def mlp(self):
        return self.client_module.mlp.fc1.weight, \
               self.client_module.mlp.fc1.bias, \
               self.client_module.mlp.fc2.weight, \
               self.client_module.mlp.fc2.bias

    def layernorm(self):
        # layer_norm2 is CLIP's pre-MLP norm, layer_norm1 its pre-attention norm.
        return self.client_module.layer_norm2.weight, \
               self.client_module.layer_norm2.bias, \
               self.client_module.layer_norm1.weight, \
               self.client_module.layer_norm1.bias

    def get_lora_params(self):
        return []
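

# Minimal usage sketch (assumes the `transformers` package is installed):
# build a Hugging Face CLIP encoder layer and check how this policy maps its
# parameters onto the fused-QKV layout consumed by the inference module.
if __name__ == "__main__":
    from transformers.models.clip.configuration_clip import CLIPVisionConfig
    from transformers.models.clip.modeling_clip import CLIPEncoderLayer

    layer = CLIPEncoderLayer(CLIPVisionConfig())
    policy = HFCLIPLayerPolicy(layer, inference=True)

    hidden_size, num_heads, ln_eps = policy.get_hidden_heads()
    qkvw, qkvb, attn_ow, attn_ob = policy.attention()

    # q_proj, k_proj and v_proj are concatenated along dim 0, so the fused
    # weight has shape [3 * hidden_size, hidden_size].
    assert qkvw.shape == (3 * hidden_size, hidden_size)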