Fix Opt injection (#2541)
* fix Opt injection & add injection verification check at inference test
* fix several issues
* remove fixture
* remove check_injection when no kernel is injected

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
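For context, a minimal sketch of the code path this PR hardens, assuming a small OPT checkpoint (the model tag, prompt, and generation arguments are illustrative, not taken from this commit):

# Sketch: kernel injection on an OPT model via deepspeed.init_inference.
# "facebook/opt-125m" and the prompt are illustrative assumptions.
import deepspeed
import torch
from transformers import pipeline

pipe = pipeline("text-generation", model="facebook/opt-125m", device=0)
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      replace_method="auto",
                                      replace_with_kernel_inject=True)
# A failed injection would silently run the stock HF layers and still
# produce output, which is what the new check_injection() helper guards against.
print(pipe("DeepSpeed is", max_new_tokens=16))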
@@ -20,7 +20,7 @@ from ..comm.comm import init_distributed
 from ..pipe import PipelineModule
 from ..moe.utils import has_moe_layers
 from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
-from ..module_inject.replace_policy import DSPolicy
+from ..module_inject.replace_policy import TransformerPolicy
 
 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -54,7 +54,7 @@ class InferenceEngine(Module):
         self.generate = self._generate
 
         if hasattr(self.module, "config"):
-            DSPolicy.hf_model_config = self.module.config
+            TransformerPolicy.hf_model_config = self.module.config
 
         # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
         # todo: this will get changed when Molly's PR on auto injection dict is merged
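The two hunks above move hf_model_config from DSPolicy onto the TransformerPolicy base class, so every transformer-layer policy subclass (such as HFOPTLayerPolicy) sees the Hugging Face config the engine stored once. A minimal sketch of that shared class-attribute pattern, using illustrative stand-in names rather than DeepSpeed's classes:

# Sketch of the shared class-attribute pattern: setting the attribute on the
# base class makes it visible to every subclass instance. Names are illustrative.
class BasePolicy:
    hf_model_config = None  # shared across all subclasses

class OPTPolicy(BasePolicy):
    def needs_relu(self):
        # every policy can consult the config the engine stored once
        return getattr(self.hf_model_config, "activation_function", None) == "relu"

class FakeConfig:
    activation_function = "relu"

BasePolicy.hf_model_config = FakeConfig()   # done once by the engine
assert OPTPolicy().needs_relu()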
@@ -114,7 +114,6 @@ class InferenceEngine(Module):
         # retain this from the old conditional argument being passed to apply_injection_policy()
         if not config.replace_with_kernel_inject:
             config.checkpoint = None
-
         if self.injection_dict:
             for client_module, injection_policy in self.injection_dict.items():
                 # construct the tuple and pass that instead of a string or dict.
@@ -665,6 +665,7 @@ class HFOPTLayerPolicy(TransformerPolicy):
+                         mlp_act_func_type=ActivationFuncType.ReLU,
                          pre_attn_norm=True)
         self.client_module = client_module
 
         try:
             import transformers
             HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
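The added mlp_act_func_type=ActivationFuncType.ReLU is the substance of the OPT fix: OPT's MLP uses ReLU rather than the GELU most transformer policies assume, so the injected kernels must be told which activation to run. This can be confirmed from the model config (the model tag is an illustrative assumption):

# Confirm OPT's activation function from its Hugging Face config.
# "facebook/opt-125m" is an illustrative model tag, not taken from this commit.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-125m")
print(config.activation_function)  # -> "relu"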
@@ -12,6 +12,8 @@ from transformers import pipeline
 from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.roberta.modeling_roberta import RobertaLayer
 from huggingface_hub import HfApi
+from deepspeed.model_implementations import DeepSpeedTransformerInference
+from torch import nn
 
 rocm_version = OpBuilder.installed_rocm_version()
 if rocm_version != (0, 0):
@@ -214,6 +216,19 @@ def assert_fn(model_w_task):
     return assert_fn
 
 
+def check_injection(model):
+    def verify_injection(module):
+        for child in module.children():
+            if isinstance(child, nn.ModuleList):
+                assert isinstance(child[0], DeepSpeedTransformerInference),\
+                    "DeepSpeed-Inference Transformer kernels has not been injected in the model"
+                break
+            else:
+                verify_injection(child)
+
+    verify_injection(model)
+
+
 """
 Tests
 """
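The helper above recursively finds the first nn.ModuleList in the module tree (the stack of decoder layers in HF models) and asserts its head was replaced with DeepSpeedTransformerInference. A CPU-only sketch of that behavior, with illustrative stub classes standing in for the real kernel module:

# Runnable sketch of what check_injection verifies, using stand-in classes
# so it runs without GPU kernels. InjectedBlock/StockBlock are illustrative.
import torch.nn as nn

class InjectedBlock(nn.Module):   # stands in for DeepSpeedTransformerInference
    pass

class StockBlock(nn.Module):      # stands in for an unreplaced HF layer
    pass

injected = nn.Sequential(nn.ModuleList([InjectedBlock(), InjectedBlock()]))
stock = nn.Sequential(nn.ModuleList([StockBlock(), StockBlock()]))

def walk_and_assert(module, expected):
    # mirrors verify_injection: find the first ModuleList and check its head
    for child in module.children():
        if isinstance(child, nn.ModuleList):
            assert isinstance(child[0], expected), "kernels not injected"
            break
        else:
            walk_and_assert(child, expected)

walk_and_assert(injected, InjectedBlock)    # passes
try:
    walk_and_assert(stock, InjectedBlock)   # raises AssertionError
except AssertionError as e:
    print("caught:", e)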
@@ -266,6 +281,7 @@ class TestModelTask(DistributedTest):
                                              replace_with_kernel_inject=True,
                                              enable_cuda_graph=enable_cuda_graph,
                                              )
+        check_injection(pipe.model)
         # Warm-up queries for perf measurement
         #for i in range(10):
         #    _ = pipe(query, **inf_kwargs)
@@ -328,6 +344,7 @@ class TestMPSize(DistributedTest):
                                          dtype=dtype,
                                          replace_method="auto",
                                          replace_with_kernel_inject=True)
+        check_injection(pipe.model)
         # Switch device to GPU so that input tensors are not on CPU
         pipe.device = torch.device(f"cuda:{local_rank}")
         ds_output = pipe(query, **inf_kwargs)
@@ -453,6 +470,7 @@ class TestLMCorrectness(DistributedTest):
                                              replace_with_kernel_inject=True,
                                              enable_cuda_graph=False,
                                              )
+        check_injection(ds_model)
         setattr(lm, model_family, ds_model)
         torch.cuda.synchronize()
         start = time.time()