Fix Opt injection (#2541)

* fix Opt injection and add an injection verification check to the inference tests

* fix several issues

* remove fixture

* remove check_injection when no kernel is injected

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Author: Reza Yazdani
Date: 2023-01-06 13:21:49 -08:00 (committed by GitHub)
Parent: a091bc223c
Commit: 95d9a1b6c3
3 changed files with 21 additions and 3 deletions

deepspeed/inference/engine.py

@@ -20,7 +20,7 @@ from ..comm.comm import init_distributed
 from ..pipe import PipelineModule
 from ..moe.utils import has_moe_layers
 from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
-from ..module_inject.replace_policy import DSPolicy
+from ..module_inject.replace_policy import TransformerPolicy
 
 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -54,7 +54,7 @@ class InferenceEngine(Module):
         self.generate = self._generate
 
         if hasattr(self.module, "config"):
-            DSPolicy.hf_model_config = self.module.config
+            TransformerPolicy.hf_model_config = self.module.config
 
         # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
         # todo: this will get changed when Molly's PR on auto injection dict is merged
@@ -114,7 +114,6 @@
         # retain this from the old conditional argument being passed to apply_injection_policy()
         if not config.replace_with_kernel_inject:
             config.checkpoint = None
-
         if self.injection_dict:
             for client_module, injection_policy in self.injection_dict.items():
                 # construct the tuple and pass that instead of a string or dict.
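Taken together, the engine-side changes stash the HuggingFace model config on TransformerPolicy, the class the concrete policies actually derive from, so a policy such as HFOPTLayerPolicy can consult it during injection. A minimal sketch of that pattern; the method name pre_attn_norm_from_config and the do_layer_norm_before lookup are illustrative, not taken from this commit:

# Sketch: class-level stash set once by the engine, read by a concrete policy.
class TransformerPolicy:
    hf_model_config = None  # set by InferenceEngine when the module has a .config

class HFOPTLayerPolicy(TransformerPolicy):
    def pre_attn_norm_from_config(self):
        # OPT's HF config exposes do_layer_norm_before; assume pre-norm if absent.
        cfg = TransformerPolicy.hf_model_config
        return getattr(cfg, "do_layer_norm_before", True)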

deepspeed/module_inject/replace_policy.py

@@ -665,6 +665,7 @@ class HFOPTLayerPolicy(TransformerPolicy):
                          mlp_act_func_type=ActivationFuncType.ReLU,
                          pre_attn_norm=True)
         self.client_module = client_module
+
         try:
             import transformers
             HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
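The _orig_layer_class assignment above is what lets auto-injection recognize OPT decoder layers, and the enclosing try block keeps the policy importable when transformers (or its OPT model) is unavailable. A rough sketch of how such a registry can be matched against a module, with find_policy_for as a hypothetical helper (the real replacement logic lives in deepspeed.module_inject):

# Hypothetical matcher: return the first policy whose registered original
# layer class the module is an instance of, or None if nothing applies.
def find_policy_for(module, policies):
    for policy_cls in policies:
        orig = getattr(policy_cls, "_orig_layer_class", None)
        if orig is not None and isinstance(module, orig):
            return policy_cls
    return None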

tests/unit/inference/test_inference.py

@@ -12,6 +12,8 @@ from transformers import pipeline
 from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.roberta.modeling_roberta import RobertaLayer
 from huggingface_hub import HfApi
+from deepspeed.model_implementations import DeepSpeedTransformerInference
+from torch import nn
 
 rocm_version = OpBuilder.installed_rocm_version()
 if rocm_version != (0, 0):
@@ -214,6 +216,19 @@ def assert_fn(model_w_task):
     return assert_fn
 
 
+def check_injection(model):
+    def verify_injection(module):
+        for child in module.children():
+            if isinstance(child, nn.ModuleList):
+                assert isinstance(child[0], DeepSpeedTransformerInference), \
+                    "DeepSpeed-Inference Transformer kernels have not been injected in the model"
+                break
+            else:
+                verify_injection(child)
+
+    verify_injection(model)
+
+
 """
 Tests
 """
@@ -266,6 +281,7 @@ class TestModelTask(DistributedTest):
                                  replace_with_kernel_inject=True,
                                  enable_cuda_graph=enable_cuda_graph,
         )
+        check_injection(pipe.model)
         # Warm-up queries for perf measurement
         #for i in range(10):
         #    _ = pipe(query, **inf_kwargs)
@@ -328,6 +344,7 @@ class TestMPSize(DistributedTest):
                                  dtype=dtype,
                                  replace_method="auto",
                                  replace_with_kernel_inject=True)
+        check_injection(pipe.model)
         # Switch device to GPU so that input tensors are not on CPU
         pipe.device = torch.device(f"cuda:{local_rank}")
         ds_output = pipe(query, **inf_kwargs)
@@ -453,6 +470,7 @@ class TestLMCorrectness(DistributedTest):
                                  replace_with_kernel_inject=True,
                                  enable_cuda_graph=False,
        )
+        check_injection(ds_model)
        setattr(lm, model_family, ds_model)
        torch.cuda.synchronize()
        start = time.time()
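In this last hunk the check sits just before a timed region; the torch.cuda.synchronize() call matters because CUDA kernel launches are asynchronous, so without it time.time() would bracket only the launch, not the GPU work. A self-contained illustration of that timing pattern (the matmul stands in for the model forward):

import time
import torch

x = torch.randn(4096, 4096, device="cuda")
torch.cuda.synchronize()  # drain pending GPU work before starting the clock
start = time.time()
y = x @ x  # stand-in for the timed forward pass
torch.cuda.synchronize()  # wait for the kernel to finish before reading the clock
print(f"elapsed: {time.time() - start:.4f} s")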