Add arctic model support by adding w2 to all_reduce (#6856)

As title says. 

Default behavior of arctic model produces shape issues with AutoTP due
to the MLP layer performing `w2 * act(w1*w3)`. However, method provided
to fix Mixtral-7x8b in #5257 does not work since the MLP for Arctic is
also used within a ModuleList for the MoE. This results in MLP weights
hiding behind individual experts as layers `#.w#`, which is not caught
by the fix in #5257. This adds the check directly within replace, where
it can check for actual layer names for the `w2` key in the model to
patch with `all_reduce`.

---------

Signed-off-by: Daniel Huang <daniel1.huang@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Daniel Huang
2024-12-18 08:09:31 -08:00
committed by GitHub
parent 4cd1d97460
commit 0b25630abe
2 changed files with 6 additions and 1 deletions

View File

@ -346,11 +346,15 @@ class AutoTP():
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears:
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

View File

@ -121,6 +121,7 @@ The following results were collected using V100 SXM2 32GB GPUs.
The following model families have been successfully tested with automatic tensor parallelism. Other models may work but have not been tested yet.
- albert
- arctic
- baichuan
- bert
- bigbird_pegasus