DeepSpeed/deepspeed/module_inject/containers/gptj.py
Olatunji Ruwase 47f9f13bd3 DeepSpeed Chat (#3186)
2023-04-11 11:53:38 -07:00


# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features.meta_tensor import MetaTensorContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv


class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.

    def create_module(self, config=None):
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        return self.module

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        # Parameter names of a GPT-J block in the Hugging Face checkpoint,
        # relative to the layer prefix.
        param_names = (
            'attn.q_proj.weight',
            'attn.k_proj.weight',
            'attn.v_proj.weight',
            'attn.out_proj.weight',
            'mlp.fc_in.weight',
            'mlp.fc_in.bias',
            'mlp.fc_out.weight',
            'mlp.fc_out.bias',
            'ln_1.weight',
            'ln_1.bias',
        )

        # Fuse the separate q/k/v projection weights into the single attn_qkvw
        # parameter of the DeepSpeed inference module.
        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
                       split_qkv=self.policy.split_qkv)
        # Attention output projection weight.
        for i in range(3, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
                       prefix + param_names[i])
        # MLP fc_in/fc_out weights and biases.
        for i in range(4, 8):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        # Input LayerNorm weight and bias.
        for i in range(8, 10):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i + 2],
                       prefix + param_names[i])


class HFGPTJLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=True)
        self.client_module = client_module
        try:
            import transformers
            HFGPTJLayerPolicy._orig_layer_class = transformers.models.gptj.modeling_gptj.GPTJBlock
        except (ImportError, AttributeError):
            # transformers is unavailable or too old to provide GPT-J.
            HFGPTJLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        # Hidden size, number of attention heads, and LayerNorm epsilon.
        return self.client_module.attn.q_proj.weight.shape[1], \
               self.client_module.attn.num_attention_heads, \
               self.client_module.ln_1.eps

    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        qw = self.client_module.attn.q_proj.weight
        kw = self.client_module.attn.k_proj.weight
        vw = self.client_module.attn.v_proj.weight

        # Concatenate q/k/v into a single QKV weight; GPT-J attention has no
        # biases, so the bias slots are None.
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)

        return qkvw, \
               None, \
               self.client_module.attn.out_proj.weight, \
               None

    def mlp(self):
        return self.client_module.mlp.fc_in.weight, \
               self.client_module.mlp.fc_in.bias, \
               self.client_module.mlp.fc_out.weight, \
               self.client_module.mlp.fc_out.bias

    def layernorm(self):
        # GPT-J uses a single LayerNorm (ln_1) per block, so the post-attention
        # norm entries are None.
        return None, \
               None, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias

    def get_lora_params(self):
        return []
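

# Usage sketch: these classes are not instantiated directly by users; they are
# picked up by DeepSpeed's kernel-injection pass when an inference engine is
# built with replace_with_kernel_inject=True. The example below is a minimal,
# hedged illustration of that flow; the checkpoint name "EleutherAI/gpt-j-6B"
# and the fp16 dtype are assumptions chosen for the sketch, not requirements.
if __name__ == "__main__":
    import deepspeed
    from transformers import AutoModelForCausalLM

    # Load a Hugging Face GPT-J model; its GPTJBlock layers are what
    # HFGPTJLayerPolicy matches (via _orig_layer_class) during injection.
    hf_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)

    # init_inference replaces each matched block with DeepSpeedGPTInference,
    # using a DS_GPTJContainer per layer to fuse and copy the parameters.
    engine = deepspeed.init_inference(hf_model,
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True)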