!226 Update sequence parallelism related code
Merge pull request !226 from 金勇旭/update-sp
@@ -336,7 +336,7 @@ def get_model():
         **init_kwargs,
     )
 
-    apply_sequence_parallel(args, config.num_attention_heads)
+    apply_sequence_parallel(args, config)
     model = apply_adapter(model, args.do_train)
 
     if args.do_train:
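The call site now passes the whole model config rather than just the head count, because the reworked apply_sequence_parallel (second file in this diff) also needs config.architectures for the fused-ops check. A minimal sketch of the attribute surface the new call assumes; DemoArgs and DemoConfig are illustrative stand-ins, not names from the repository:

from dataclasses import dataclass, field

@dataclass
class DemoArgs:
    sequence_parallel_size: int = 2
    do_train: bool = True

@dataclass
class DemoConfig:
    num_attention_heads: int = 32
    # architectures drives the SUPPORTED_FUSED_MODELS check added in this change
    architectures: list = field(default_factory=lambda: ["SomeCausalLM"])

# apply_sequence_parallel(DemoArgs(), DemoConfig()) can now read both attributes.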
@@ -17,8 +17,11 @@ import torch
 import torch.distributed as dist
 
 from openmind.flow.model.sequence_parallel.ulysses import UlyssesAttention
+from openmind.integrations.transformers.npu_fused_ops.sdk import SUPPORTED_FUSED_MODELS
+from openmind.utils import logging
 
 _SEQUENCE_PARALLEL_GROUP = None
+logger = logging.get_logger(__name__)
 
 
 class DistributedTrainingModule:
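For context, _SEQUENCE_PARALLEL_GROUP and DistributedTrainingModule.get_sequence_parallel_group() (used further down) refer to a torch.distributed process group. The following is a hypothetical sketch of how such a group is commonly created with dist.new_group, assuming consecutive ranks share one sequence; it is not the openmind implementation.

import torch.distributed as dist

_SEQUENCE_PARALLEL_GROUP = None

def init_sequence_parallel_group(sequence_parallel_size):
    # Partition the world into contiguous groups of `sequence_parallel_size` ranks.
    # Every rank must call dist.new_group() for every group, even ones it is not in.
    global _SEQUENCE_PARALLEL_GROUP
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    for start in range(0, world_size, sequence_parallel_size):
        ranks = list(range(start, start + sequence_parallel_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group
    return _SEQUENCE_PARALLEL_GROUP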
@@ -89,15 +92,21 @@ def new_attn_forward(
     return attn_output
 
 
-def apply_sequence_parallel(args, num_head):
-    if num_head % args.sequence_parallel_size:
-        raise ValueError(
-            "num_attention_heads must be divisible by sequence_parallel_size for sequence parallel training. "
-            f"{num_head} is not divisible by {args.sequence_parallel_size}."
-        )
+def apply_sequence_parallel(args, config):
+    if args.sequence_parallel_size > 1:
+        if config.num_attention_heads % args.sequence_parallel_size:
+            raise ValueError(
+                "num_attention_heads must be divisible by sequence_parallel_size for sequence parallel training. "
+                f"{config.num_attention_heads} is not divisible by {args.sequence_parallel_size}."
+            )
+
+        if not (config.architectures and config.architectures[0] in SUPPORTED_FUSED_MODELS and args.do_train):
+            raise ValueError(
+                "Sequence parallel training does not support models that cannot enable NPU fused ops."
+            )
+
+        group_this = DistributedTrainingModule.get_sequence_parallel_group()
+        original_attn = torch.nn.functional.scaled_dot_product_attention
+        new_attention_forward = partial(new_attn_forward, group=group_this, attn_fn=original_attn)
+        torch.nn.functional.scaled_dot_product_attention = new_attention_forward
+        logger.info_rank0("Enable sequence parallel training for the model.")
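The new body replaces torch.nn.functional.scaled_dot_product_attention with a partial of new_attn_forward that carries the sequence-parallel group and the original kernel. Below is a self-contained sketch of that monkeypatch pattern only; demo_attn_forward is a stand-in, not the real new_attn_forward, and a real Ulysses-style implementation would all-to-all the heads across the group around the inner call.

from functools import partial

import torch
import torch.nn.functional as F


def demo_attn_forward(query, key, value, *args, group=None, attn_fn=None, **kwargs):
    # Stand-in: a real implementation would redistribute heads/sequence across
    # `group` before and after calling the original kernel bound to `attn_fn`.
    return attn_fn(query, key, value, *args, **kwargs)


original_attn = F.scaled_dot_product_attention
F.scaled_dot_product_attention = partial(demo_attn_forward, group=None, attn_fn=original_attn)

q = k = v = torch.randn(1, 2, 4, 8)
out = F.scaled_dot_product_attention(q, k, v)  # now routed through demo_attn_forward
assert out.shape == (1, 2, 4, 8)

F.scaled_dot_product_attention = original_attn  # restore the original kernel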