diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 5e0aa0286..3429ceb0a 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -307,7 +307,7 @@ class AutoTP(): # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce. elif 'w2' in layer and 'Mixtral' in str(type(module)): gem_list = gem_list + [layer] - elif "self_attn.dense" in layer and "Phi" in str(type(module)): + elif 'self_attn.dense' in layer and 'Phi' in str(type(module)): gem_list = gem_list + [layer] layer_list = []