!226 update sequence parallelism related

Merge pull request !226 from 金勇旭/update-sp
Author: 金勇旭
Date: 2025-05-30 03:32:32 +00:00
Committed by: i-robot
Parent: 2d76b0854a
Commit: a82627330b
2 changed files with 17 additions and 8 deletions


@@ -336,7 +336,7 @@ def get_model():
         **init_kwargs,
     )
-    apply_sequence_parallel(args, config.num_attention_heads)
+    apply_sequence_parallel(args, config)
     model = apply_adapter(model, args.do_train)
     if args.do_train:

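For context on the first hunk: the call site now hands the whole model config to apply_sequence_parallel instead of only the head count, so the callee can validate both config.num_attention_heads and config.architectures. A minimal sketch of the new call contract, with SimpleNamespace standing in for the real args and config objects (the field values and architecture name are illustrative, not taken from this repository):

from types import SimpleNamespace

# Hypothetical stand-ins for the real args and HF-style config objects.
args = SimpleNamespace(sequence_parallel_size=2, do_train=True)
config = SimpleNamespace(num_attention_heads=32, architectures=["Qwen2ForCausalLM"])

# Old call shape: apply_sequence_parallel(args, config.num_attention_heads)
# New call shape: pass the whole config so the callee can also check
# config.architectures against SUPPORTED_FUSED_MODELS.
apply_sequence_parallel(args, config)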

@@ -17,8 +17,11 @@ import torch
 import torch.distributed as dist

 from openmind.flow.model.sequence_parallel.ulysses import UlyssesAttention
+from openmind.integrations.transformers.npu_fused_ops.sdk import SUPPORTED_FUSED_MODELS
+from openmind.utils import logging

 _SEQUENCE_PARALLEL_GROUP = None
+logger = logging.get_logger(__name__)


 class DistributedTrainingModule:
@@ -89,15 +92,21 @@ def new_attn_forward(
     return attn_output


-def apply_sequence_parallel(args, num_head):
-    if num_head % args.sequence_parallel_size:
-        raise ValueError(
-            "num_attention_head must be divisible by sequence_parallel_size for sequence parallel training."
-            f"{num_head} can not be devisible by {args.sequence_parallel_size}"
-        )
+def apply_sequence_parallel(args, config):
+    if args.sequence_parallel_size > 1:
+        if config.num_attention_heads % args.sequence_parallel_size:
+            raise ValueError(
+                "num_attention_heads must be divisible by sequence_parallel_size for sequence parallel training. "
+                f"{config.num_attention_heads} is not divisible by {args.sequence_parallel_size}."
+            )
+        if not (config.architectures and config.architectures[0] in SUPPORTED_FUSED_MODELS and args.do_train):
+            raise ValueError(
+                "Sequence parallel training does not support models that cannot enable NPU fused options."
+            )
         group_this = DistributedTrainingModule.get_sequence_parallel_group()
         original_attn = torch.nn.functional.scaled_dot_product_attention
         new_attention_forward = partial(new_attn_forward, group=group_this, attn_fn=original_attn)
         torch.nn.functional.scaled_dot_product_attention = new_attention_forward
+        logger.info_rank0("Enable sequence parallel training for the model.")
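The enabling mechanism in the second hunk is a functools.partial monkey-patch: the original torch.nn.functional.scaled_dot_product_attention is captured once as attn_fn, and every later SDPA call in the process is routed through new_attn_forward with the sequence-parallel group bound in. The head-count check exists because Ulysses-style sequence parallelism shards attention heads across the group (e.g. 32 heads over a group of 4 leaves 8 heads per rank, so the head count must divide evenly). A minimal runnable sketch of the patching pattern, with sp_attn_forward as a hypothetical stand-in for new_attn_forward and the Ulysses all-to-all omitted:

from functools import partial

import torch
import torch.nn.functional as F

def sp_attn_forward(query, key, value, *args, group=None, attn_fn=None, **kwargs):
    # Stand-in for new_attn_forward: a real implementation would all-to-all
    # the head/sequence dims across `group` before and after attn_fn; this
    # sketch simply delegates to the captured original kernel.
    return attn_fn(query, key, value, *args, **kwargs)

original_attn = F.scaled_dot_product_attention   # capture the unpatched kernel once
F.scaled_dot_product_attention = partial(        # patch it process-wide
    sp_attn_forward, group=None, attn_fn=original_attn
)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (1, 8, 16, 64)

Capturing the original kernel before the assignment matters: attn_fn must point at the unpatched function, otherwise re-applying the patch would wrap the wrapper. The diff follows the same order, taking original_attn immediately before the assignment.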