# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import tqdm
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d
from deepspeed.module_inject.layers import is_autotp_training_mode
from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size

from .load_checkpoint import load_model_with_checkpoint
import time

from .utils import policy_to_ds_container
import gc


def get_transformer_name(replaced_module):
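    # Illustrative note (not in the original source): for a typical Hugging Face causal-LM
    # wrapper this returns the dotted prefix of the transformer block list, e.g. something
    # like 'transformer.h' for a GPT-2-style model; the exact value depends on the model's
    # attribute names and is an assumption here.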
    from .containers import supported_models
    from torch.nn import ModuleList
    transformer_name = ''
    for n, c in replaced_module.named_children():
        if c.__class__ in supported_models:
            transformer_name += n + '.'
            for name, child in c.named_children():
                if child.__class__ is ModuleList:
                    transformer_name += name
                    break
            break
    return transformer_name


class GroupQuantizer:
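    """Symmetric per-group INT8 quantizer used by the kernel-injection path.

    Usage sketch (illustrative only, not part of the original file; the tensor name is a placeholder):

        quantizer = GroupQuantizer(q_int8=True)
        qweight = quantizer.quantize(some_fp16_weight)
        # qweight is an int8 Parameter carrying per-group scales in its `.scale` attribute.
    """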

    def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
        self.group_size = group_size
        self.num_bits = num_bits
        self.q_int8 = q_int8

        self.num_groups = num_groups

    def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
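        # Added descriptive comment (not in the original): the tensor is grouped along its
        # leading dimension, each group gets a symmetric scale from its min/max, values are
        # rounded and clamped to the signed num_bits range, and the scales (plus extra scales
        # for the two halves split along parallel_dim) are attached to the returned Parameter
        # as `.scale` for the inference kernels to consume.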
        if not self.q_int8 or not qkv:
            inputs = torch.nn.Parameter(inputs, requires_grad=False)
            inputs.scale = torch.empty(1)
            return inputs
        q_range = 2**self.num_bits
        num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
        inputs = inputs.to(get_accelerator().current_device_name())
        input_flat = inputs.reshape(num_groups, -1).contiguous()
        input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
        input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
        scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
        input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
        inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
        out = torch.nn.Parameter(inputs_q, requires_grad=False)
        inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
        input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
        input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
                  for i in range(2)]

        out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], dim=0).reshape(num_groups,
                                                                                                   -1).contiguous()
        return out


def _module_match(module):
    for policy in generic_policies:
        policy = policy()
        if policy.match(module):
            return policy
    return None


def generic_injection(module, dtype=None, enable_cuda_graph=True):

    def replace_attn(child, policy):
        policy_attn = policy.attention(child)
        if policy_attn is None:
            return child
        if len(policy_attn) == 5:
            qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
            qw, kw, vw = torch.empty(0), torch.empty(0), torch.empty(0)
        else:
            qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn
            qkvw = torch.empty(0)

        config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=heads,
            dtype=dtype,
            triangular_masking=False,
            max_out_tokens=4096,
        )
        attn_module = DeepSpeedDiffusersAttention(config)

        def transpose(data):
            data = data.contiguous()
            data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            data = data.to(get_accelerator().current_device_name())
            return data

        if len(policy_attn) == 5:
            assert qkvw is not None and qkvw.data is not None, "qkvw can't be None"
            attn_module.attn_qkvw.data = transpose(qkvw.data)
        else:
            attn_module.attn_qkvw = None
            assert qw is not None and qw.data is not None, "qw can't be None"
            attn_module.attn_qw.data = transpose(qw.data)
            assert kw is not None and kw.data is not None, "kw can't be None"
            attn_module.attn_kw.data = transpose(kw.data)
            assert vw is not None and vw.data is not None, "vw can't be None"
            attn_module.attn_vw.data = transpose(vw.data)

        attn_module.attn_qkvb = None
        attn_module.attn_ow.data = transpose(attn_ow.data)
        attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
        return attn_module

    def replace_attn_block(child, policy):
        config = Diffusers2DTransformerConfig()
        return DeepSpeedDiffusersTransformerBlock(child, config)

    if isinstance(module, torch.nn.Module):
        pass
    else:
        if dtype not in [torch.float16, torch.half]:
            raise ValueError("Generic injection only supported with FP16")

        try:
            import diffusers
            if hasattr(diffusers.models.attention, 'CrossAttention'):
                cross_attention = diffusers.models.attention.CrossAttention
            else:
                cross_attention = diffusers.models.attention_processor.Attention
            attention_block = diffusers.models.attention.BasicTransformerBlock
            new_policies = {
                cross_attention: replace_attn,
                attention_block: replace_attn_block,
            }
        except ImportError:
            new_policies = {}

        #replace_transformer_layer(None,
        #                          module.text_encoder,
        #                          training=False,
        #                          replace_with_kernel_inject=True,
        #                          triangular_masking=True,
        #                          max_out_tokens=8192)
        from ..model_implementations.transformers.clip_encoder import DSClipEncoder
        cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
        setattr(module, 'text_encoder', cg_encoder)
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
            policy = _module_match(sub_module)

            if policy is not None:

                def _replace_module(module, policy):
                    for name, child in module.named_children():
                        _replace_module(child, policy)
                        if child.__class__ in new_policies:
                            replaced_module = new_policies[child.__class__](child, policy)
                            setattr(module, name, replaced_module)

                _replace_module(sub_module, policy)
                new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
                print(f"**** found and replaced {name} w. {type(new_module)}")
                setattr(module, name, new_module)


container_g = None


def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine
        config: top-level DS Inference config defined in inference/config.py
        model_config: HuggingFace model config passed from the inference/engine.py
    Returns:
        Updated nn.module with replaced transformer layers
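
    Example (illustrative sketch; in practice this function is normally reached through
    ``deepspeed.init_inference`` rather than called directly, and the keyword arguments
    shown are assumptions):

        import deepspeed
        engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)
        # init_inference builds the inference config and invokes replace_transformer_layer internally.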
    """
    # defining globals as internally defined functions inherit these everywhere
    quantize = (config.dtype == torch.int8)
    # todo: Refactor later. In future, let's minimize the style used above and use config.** instead

    linear_layer_setting = None
    '''
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers and embedding layers
    '''
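    # Illustrative only (not from the original source): a caller could in principle supply
    # something like linear_layer_setting = (CustomLinear, CustomEmbedding), where both class
    # names here are hypothetical placeholders.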
    micro_batch_size = -1
    seed = -1
    local_rank = -1

    mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
                                          mp_size=config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
        policy = policy_cls(child, inference=inference)
        if not policy.cuda_graph_supported:
            # the policy reports that CUDA graph is not supported; raise an error if it was requested
            assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"

        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        # 1. Create a model-specific container object using the policy object.
        _container = policy_to_ds_container(policy=policy,
                                            config=config,
                                            model_config=model_config,
                                            layer_id=layer_id,
                                            child=child)
        _container.set_moe(moe)

        # 2. Set the tensor parallelism config
        _container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Initialize tensors
        _container.initialize_tensors()

        # 4. deal with data types -- needs refactor to use dtype instead of fp16
        if config.dtype in [torch.float16, torch.bfloat16, torch.int8]:
            _container.convert_to_required_dtype()

        # 5. Set the quantization config
        quantizer = GroupQuantizer(q_int8=quantize)
        _container.set_quantization_config(quantizer)

        # 6. create a DS Inference config object
        _container.create_ds_model_config()

        # 7. use the config and create the module
        _container.create_module()

        # 8. transpose the weights and bias if needed
        _container.transpose()

        # 9. deal with tensor parallelism.
        _container.apply_tensor_parallelism(mp_replace)

        # 10. copy the tensors from the model-specific container to the new module
        _container.copy_data_to_new_module()

        # 11. set global for generic checkpoint loading
        global container_g

        if container_g is None:
            container_g = _container

        return _container.module

    def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
        #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

        # 1. Create AutoTP object
        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
                         config.keep_module_on_host)

        # 2. Set the tensor parallelism config
        _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Try to get num_key_heads from model_config.num_key_value_heads
        if hasattr(model_config, "vision_config"):
            if "MllamaVisionEncoderLayer" in str(module):
                num_kv_heads = _autotp.get_model_num_kv_heads(model_config.vision_config)
            elif hasattr(model_config, "text_config"):
                num_kv_heads = _autotp.get_model_num_kv_heads(model_config.text_config)
            else:
                num_kv_heads = _autotp.get_model_num_kv_heads(model_config)
        else:
            num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

        # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
        set_num_kv_heads(num_kv_heads)

        # 4.1 Get n_embd
        n_embd = None
        multi_query_n_embd_names = ['n_embd', 'hidden_size']
        for name in multi_query_n_embd_names:
            if hasattr(model_config, name):
                n_embd = getattr(model_config, name)
            if n_embd is not None:
                break

        # 4.2 set n_embd
        set_n_embd(n_embd)

        # 4.3 set attention_heads
        if hasattr(model_config, 'num_attention_heads'):
            set_num_attention_heads(getattr(model_config, 'num_attention_heads'))

        # 4.4 set tp_grain_size
        set_tp_grain_size(config.tensor_parallel.tp_grain_size)

        # 5. Set linear policies
        _autotp.update_linear_policies()

        # 6. Replace modules
        if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
            return _autotp._replace_last_linear_module(module)
        return _autotp._replace_module(module)

    def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
        # copy relevant state from child -> new module
        if not is_autotp_training_mode() and config.replace_with_kernel_inject:
            new_module = replace_with_policy(child,
                                             _policy,
                                             config.triangular_masking,
                                             inference=True,
                                             layer_id=layer_id)
        else:
            new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)

        return new_module

    def set_lm_head(module):
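        # Added descriptive comment (based on the code below): ties a meta-tensor lm_head to the
        # input embedding weights when possible, then tensor-parallel shards the final projection
        # (lm_head / embed_out) via replace_wo_policy.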
        if is_autotp_training_mode():
            # we need to handle autoTP training mode separately.
            return

        embedding_weight = None
        for n, p in module.named_parameters():
            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
                embedding_weight = p
        if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
                module.lm_head, "weight") and module.lm_head.weight.is_meta:
            module.lm_head.weight = embedding_weight
        # enable tensor parallel for the last linear
        if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
                module.lm_head, torch.nn.Linear):
            module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
        elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
                module.embed_out, torch.nn.Linear):
            module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
        elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
            module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
        return module

    def conv2d_parallel_shard_weights(model, rank, world_size):
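        # Added descriptive comment: recursively walks the model and swaps Conv2d children whose
        # names appear in shard_oc_name / shard_ic_name for output-channel / input-channel sharded
        # tensor-parallel variants.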
        # add conv policy
        shard_oc_name = ["conv1"]
        shard_ic_name = ["conv2"]
        for name, sub_m in model.named_children():
            for l_name, l_sub_m in sub_m.named_children():
                if l_name in shard_oc_name:
                    TPConv2d = TensorParallelOcShardConv2d(
                        l_sub_m,
                        rank,
                        world_size,
                    )
                    setattr(sub_m, l_name, TPConv2d)
                if l_name in shard_ic_name:
                    TPConv2d = TensorParallelIcShardConv2d(
                        l_sub_m,
                        rank,
                        world_size,
                    )
                    setattr(sub_m, l_name, TPConv2d)
            conv2d_parallel_shard_weights(sub_m, rank, world_size)

    if checkpoint_dict is not None and not config.replace_with_kernel_inject:
        # AutoTP shard loading
        checkpoint = checkpoint_dict["checkpoints"]
        pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
        for i in range(len(checkpoint)):
            checkpoint_file = os.path.join(config.base_dir, checkpoint[i])
            replaced_module = replace_module(model=model,
                                             orig_class=orig_layer_impl,
                                             replace_fn=replace_fn,
                                             _replace_policy=config.injection_policy_tuple,
                                             checkpoint=checkpoint_file)
            pbar.update(1)
            gc.collect()
        # conv2d tp module replace
        # Currently only applied to the Yuan model. TODO: add a model list and conv policy to decide whether to replace conv layers.
        if 'Yuan' in str(replaced_module):
            conv2d_parallel_shard_weights(replaced_module, dist.get_rank(), dist.get_world_size())
    else:
        replaced_module = replace_module(model=model,
                                         orig_class=orig_layer_impl,
                                         replace_fn=replace_fn,
                                         _replace_policy=config.injection_policy_tuple)
    # By default, AutoTP also applies tensor parallelism to lm_head
    if not config.replace_with_kernel_inject:
        replaced_module = set_lm_head(replaced_module)

    quantizer = GroupQuantizer(q_int8=quantize)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None and config.replace_with_kernel_inject:
        assert container_g.ckpt_load_enabled, \
            f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
        ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
        base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)

        if ckpt_type == 'pp' and type(checkpoint) is list:
            pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

            for i in range(len(checkpoint)):
                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu', weights_only=False)]
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           container=container_g)
                pbar.update(1)
        else:
            num_checkpoints = len(ckpt_list) // ckpt_mp_size
            tp_split_size = (world_size / ckpt_mp_size)
            sd_offset = int(rank / tp_split_size)
            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
            pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                pbar.update(1)
                ckpt_index = i * ckpt_mp_size + sd_offset
                ckpt_files = [
                    os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                    for j in range(sd_count)
                ]
                sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False) for ckpt_file in ckpt_files]
                load_model_with_checkpoint(replaced_module,
                                           sds,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           int(rank % tp_split_size),
                                           container=container_g)
                sds = [None for _ in sds]
                gc.collect()

            if "non_tp" in checkpoint:
                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")

                for i in range(len(checkpoint["non_tp"])):
                    pbar.update(1)
                    ckpt_file = os.path.join(base_dir1,
                                             checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
                    sds = [torch.load(ckpt_file, map_location='cpu', weights_only=False)]
                    load_model_with_checkpoint(replaced_module,
                                               sds,
                                               mp_replace,
                                               ckpt_type,
                                               ckpt_mp_size,
                                               quantizer,
                                               int(rank % tp_split_size),
                                               container=container_g)
                    sds = [None for _ in sds]
                    gc.collect()
        set_lm_head(replaced_module)
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if not is_autotp_training_mode() and config.save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json
        num_partitions = 8

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = 'non-tp.pt'
        ckpt_files = [non_tp_ckpt_name]
        os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)

        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({
                    k: v
                    for k, v in dict(replaced_module.state_dict()).items() if transformer_name not in k
                }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')

            dtype_reprs = {
                torch.float32: 'float32',
                torch.float16: 'float16',
                torch.int8: 'int8',
                torch.bfloat16: 'bfloat16'
            }

            ckpt_config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{config.save_mp_checkpoint_path}',
                'checkpoints': {
                    "non_tp": ckpt_files,
                    "tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
                },
                'version': 1.0,
                'parallelization': 'tp',
                'tp_size': world_size,
                'dtype': dtype_reprs[config.dtype]
            })
            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
                cfg.write(ckpt_config)

        rep_sd = replaced_module.state_dict()
        for n, p in replaced_module.named_parameters():
            if hasattr(p, 'scale'):
                rep_sd[n] = [p, p.scale]
        keys = list(rep_sd.keys())
        partition_size = (len(keys) // num_partitions + 1)
        for m in range(num_partitions):
            torch.save(
                OrderedDict({
                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
                    for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
                }), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

    return replaced_module


def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        config (dict): model config containing hidden size, attention heads, etc.
    Returns:
        Updated nn.module with original bert-style transformer layers
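
    Example (illustrative sketch only; ``bert_config`` is an assumed placeholder for the layer config):

        from transformers.models.bert.modeling_bert import BertLayer
        model = revert_transformer_layer(BertLayer, model, bert_config)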
    """

    def replace_fn(child, _replace_policy, layer_id):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        qw, kw, vw = torch.chunk(qkvw, 3, axis=0)
        qb, kb, vb = torch.chunk(qkvb, 3, axis=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn,
                          _replace_policy=None)


def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
                             desired type and return a new instance.
    Returns:
        A modified ``model``.
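
    Example (illustrative sketch; ``BertLayer`` and ``my_replace_fn`` are assumed placeholders,
    and ``my_replace_fn`` must follow the replace_fn calling convention used below):

        model = replace_module(model, BertLayer, my_replace_fn, _replace_policy=None)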
    """
    sd = None
    if checkpoint is not None:
        if checkpoint.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd = load_file(checkpoint)
        else:
            sd = torch.load(checkpoint, map_location='cpu', weights_only=False)

    policy = {}
    if orig_class is not None:
        policy.update({orig_class: (replace_fn, _replace_policy)})
    else:
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.update({orig_layer_class: (replace_fn, plcy)})
            elif plcy._orig_layer_class is not None:
                policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
    assert len(policy.items()) > 0,\
        "No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBertLayerPolicy})." +\
        " You can find some samples here: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

    replaced_module, _ = _replace_module(model, policy, state_dict=sd)
    return replaced_module


from ..pipe import PipelineModule

import re


def skip_level_0_prefix(model, state_dict):
    model = str(model)
    key = re.search(r": (.*?)Model", model)
    if key is None:
        key = re.search(r": (.*?)Stack", model)
    if key is None:
        key = re.match(r"(.*?)Model", model)
    # if keys start with 'model.', don't skip level 0 prefix
    if state_dict is not None:
        for item in state_dict.keys():
            if re.match("^model[.]", item):
                return False
    if key is not None and key.group(1).lower() in ["bloom", "opt"]:
        return True
    return False


def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.
    Returns:
        Modified ``model``.
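
    Note (added for clarity, not in the original docstring): each ``policies`` entry maps an
    original layer class to a ``(replace_fn, policy)`` tuple, roughly of the form
    ``{BertLayer: (replace_fn, HFBertLayerPolicy)}``.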
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            replaced_module = policies[child.__class__][0](child,
                                                           policies[child.__class__][-1],
                                                           layer_id,
                                                           prefix=prefix + name,
                                                           state_dict=state_dict)
            setattr(model, name, replaced_module)
            if isinstance(model, PipelineModule):
                assert hasattr(model, 'forward_funcs'),\
                    "we require pipe-module to have the list of fwd_functions"
                model.forward_funcs[model.fwd_map[name]] = replaced_module
            layer_id += 1
        else:
            checking_key = prefix + name + '.'
            if Loading.is_load_module(child) and state_dict is not None:
                if any(checking_key in item for item in state_dict):
                    Loading.load(
                        child,
                        state_dict,
                        checking_key,
                    )
                else:
                    continue
            if len(child._buffers) != 0 and state_dict is not None:
                Loading.load_buffer(child, state_dict, checking_key)
            _, layer_id = _replace_module(child,
                                          policies,
                                          prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \
                                          prefix + name + '.',
                                          layer_id=layer_id,
                                          level_id=level_id + 1,
                                          state_dict=state_dict)

    # Add the reset_cache func to the model, so that it can be called at the beginning of text generation.
    model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
    return model, layer_id