diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index c66856223..f5a8a2caa 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -20,7 +20,7 @@ from ..comm.comm import init_distributed
 from ..pipe import PipelineModule
 from ..moe.utils import has_moe_layers
 from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
-from ..module_inject.replace_policy import DSPolicy
+from ..module_inject.replace_policy import TransformerPolicy
 
 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -54,7 +54,7 @@ class InferenceEngine(Module):
         self.generate = self._generate
 
         if hasattr(self.module, "config"):
-            DSPolicy.hf_model_config = self.module.config
+            TransformerPolicy.hf_model_config = self.module.config
 
         # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
         # todo: this will get changed when Molly's PR on auto injection dict is merged
@@ -114,7 +114,6 @@ class InferenceEngine(Module):
         # retain this from the old conditional argument being passed to apply_injection_policy()
         if not config.replace_with_kernel_inject:
             config.checkpoint = None
-
         if self.injection_dict:
             for client_module, injection_policy in self.injection_dict.items():
                 # construct the tuple and pass that instead of a string or dict.
diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py
index f0a5acdc2..b3ad67aac 100755
--- a/deepspeed/module_inject/replace_policy.py
+++ b/deepspeed/module_inject/replace_policy.py
@@ -665,6 +665,7 @@ class HFOPTLayerPolicy(TransformerPolicy):
                          mlp_act_func_type=ActivationFuncType.ReLU,
                          pre_attn_norm=True)
         self.client_module = client_module
+
         try:
             import transformers
             HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 83c24fe67..e6415cb16 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -12,6 +12,8 @@ from transformers import pipeline
 from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.roberta.modeling_roberta import RobertaLayer
 from huggingface_hub import HfApi
+from deepspeed.model_implementations import DeepSpeedTransformerInference
+from torch import nn
 
 rocm_version = OpBuilder.installed_rocm_version()
 if rocm_version != (0, 0):
@@ -214,6 +216,19 @@ def assert_fn(model_w_task):
     return assert_fn
 
 
+def check_injection(model):
+    def verify_injection(module):
+        for child in module.children():
+            if isinstance(child, nn.ModuleList):
+                assert isinstance(child[0], DeepSpeedTransformerInference),\
+                    "DeepSpeed-Inference Transformer kernels have not been injected into the model"
+                break
+            else:
+                verify_injection(child)
+
+    verify_injection(model)
+
+
 """
 Tests
 """
@@ -266,6 +281,7 @@ class TestModelTask(DistributedTest):
            replace_with_kernel_inject=True,
            enable_cuda_graph=enable_cuda_graph,
        )
+        check_injection(pipe.model)
        # Warm-up queries for perf measurement
        #for i in range(10):
        #    _ = pipe(query, **inf_kwargs)
@@ -328,6 +344,7 @@ class TestMPSize(DistributedTest):
                                              dtype=dtype,
                                              replace_method="auto",
                                              replace_with_kernel_inject=True)
+        check_injection(pipe.model)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(f"cuda:{local_rank}")
        ds_output = pipe(query, **inf_kwargs)
@@ -453,6 +470,7 @@ class TestLMCorrectness(DistributedTest):
            replace_with_kernel_inject=True,
            enable_cuda_graph=False,
        )
+        check_injection(ds_model)
        setattr(lm, model_family, ds_model)
        torch.cuda.synchronize()
        start = time.time()
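
For context, the check_injection helper this patch adds to tests/unit/inference/test_inference.py is designed to run immediately after deepspeed.init_inference, so a test fails fast whenever kernel injection silently falls back to the stock HuggingFace modules. Below is a minimal usage sketch, assuming a HuggingFace text-generation pipeline; the gpt2 model, the standalone-script framing, and importing check_injection outside the test suite are illustrative assumptions, not part of this patch.

    import torch
    import deepspeed
    from transformers import pipeline

    # check_injection is the helper this patch adds to
    # tests/unit/inference/test_inference.py; assume it is copied or
    # imported into this script.
    from tests.unit.inference.test_inference import check_injection

    # Build a stock HF pipeline, then let DeepSpeed replace the transformer
    # blocks with its fused inference kernels.
    pipe = pipeline("text-generation", model="gpt2", framework="pt")
    pipe.model = deepspeed.init_inference(pipe.model,
                                          mp_size=1,
                                          dtype=torch.float16,
                                          replace_method="auto",
                                          replace_with_kernel_inject=True)

    # Recursively walks the module tree; on the first nn.ModuleList child it
    # finds, asserts element 0 is a DeepSpeedTransformerInference layer,
    # i.e. the replacement actually happened.
    check_injection(pipe.model)

This works because HuggingFace models keep their stack of transformer layers in an nn.ModuleList, so inspecting the first element of the first ModuleList encountered is enough in practice to tell whether replace_with_kernel_inject took effect.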