!216 support sequence parallel algorithm
Merge pull request !216 from 金勇旭/sp
@@ -231,6 +231,24 @@ def validate_args(args):
    elif args.max_length is None and not args.do_train:
        args.max_length = 1024

    if args.sequence_parallel_size < 1:
        raise ValueError(
            f"sequence_parallel_size must be greater than 0. Received value: {args.sequence_parallel_size}."
        )
    if args.sequence_parallel_size > 1:
        if args.stage != Stages.SFT:
            raise ValueError(
                f"Currently, sequence parallelism only supports stage {Stages.SFT}. If you want to use sequence parallelism, you must specify stage as {Stages.SFT}."
            )
        if not args.use_npu_fusion_attention:
            raise ValueError("use_npu_fusion_attention must be specified for sequence parallel training.")
        if args.max_length is None:
            raise ValueError("max_length must be specified for sequence parallel training.")
        if args.max_length % (args.sequence_parallel_size * 8):
            raise ValueError(
                f"max_length must be divisible by `sequence_parallel_size * 8` for sequence parallel training. Received value {args.max_length} is not divisible by `{args.sequence_parallel_size} * 8`."
            )

    # When ASCEND_RT_VISIBLE_DEVICES is set to "5,6", the process will see devices 5 and 6 as devices 0 and 1 at runtime.
    # Set args.device to str because `device: 0` in a yaml file is parsed as an int.
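As an aside, a minimal sketch of the constraint enforced above: the SFT collator pads to a multiple of 8 and every sequence-parallel rank must receive an equal slice, so max_length has to be divisible by sequence_parallel_size * 8. The values below are made up for illustration.

sequence_parallel_size = 4
max_length = 4096

# Mirrors the validation above: reject lengths that cannot be split evenly.
if max_length % (sequence_parallel_size * 8):
    raise ValueError("max_length must be divisible by sequence_parallel_size * 8")

per_rank_length = max_length // sequence_parallel_size  # tokens handled by each NPU
assert per_rank_length % 8 == 0  # still compatible with pad_to_multiple_of=8
print(per_rank_length)  # 1024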
@@ -487,6 +505,9 @@ def _add_model_args(parser):
    )
    group.add_argument("--print_param_status", type=str2bool, default=False, help="Print model parameters status.")
    group.add_argument("--offload_folder", type=str, default=None, help="Path to offload model weights.")
    group.add_argument(
        "--sequence_parallel_size", type=int, default=1, help="Number of NPUs used to process the dataset sequence."
    )

    return parser
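For reference, a small sketch of how the new flag behaves when parsed on its own; the parser wiring below is simplified and is not the project's actual argument setup.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sequence_parallel_size", type=int, default=1, help="Number of NPUs used to process the dataset sequence."
)

args = parser.parse_args(["--sequence_parallel_size", "4"])
print(args.sequence_parallel_size)  # 4; the default of 1 leaves sequence parallelism disabled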
@@ -200,8 +200,11 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32
    require_position_ids: bool = False

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        if not self.require_position_ids:
            features = [{k: v for k, v in feature.items() if k != "position_ids"} for feature in features]
        features = super().__call__(features)
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -14,7 +14,7 @@
import os
import json
from functools import partial
from typing import Optional, List, Dict
from typing import Optional, List, Dict, Literal

import numpy as np
from datasets import Dataset
@@ -32,6 +32,7 @@ from openmind.flow.datasets.preprocess import (
from openmind.flow.arguments import get_args
from openmind.flow.datasets.template import Template
from openmind.flow.datasets.parser import get_dataset_attr
from openmind.flow.datasets.preprocess.sequence_parallel import pad_sequence, sp_split
from openmind.utils.loader_utils import get_platform_loader

logger = get_logger(__name__)
@@ -128,26 +129,6 @@ def _get_merged_datasets(dataset_names: Optional[str]):
    return merge_datasets(aligned_datasets)


def get_dataset_module(
    tokenizer: AutoTokenizer,
    template: Template,
    processor: Optional[ProcessorMixin] = None,
):
    args = get_args()

    with args.hf_seq2seq_args.main_process_first(desc="load dataset"):
        train_dataset = _get_merged_datasets(args.dataset)
        eval_dataset = _get_merged_datasets(args.eval_dataset)

    logger.debug("Finish load data, train_dataset = {}, eval_dataset = {}".format(train_dataset, eval_dataset))

    with args.hf_seq2seq_args.main_process_first(desc="preprocess dataset"):
        train_dataset = _get_preprocessed_dataset(train_dataset, template, tokenizer, processor)
        eval_dataset = _get_preprocessed_dataset(eval_dataset, template, tokenizer, processor)
    dataset_module = {"train_dataset": train_dataset, "eval_dataset": eval_dataset}
    return dataset_module


def _get_preprocessed_dataset(
    dataset: Dataset,
    template: Template,
@@ -206,3 +187,70 @@ def _get_preprocess_func(template, tokenizer, processor):
    else:
        raise NotImplementedError
    return preprocess_func


def _get_sequence_parallel_dataset(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
):
    args = get_args()

    if dataset is None:
        return None

    kwargs = dict(
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=args.local_process_index != 0,
        desc="Running padding split on dataset",
    )
    pad_sequence_func = _get_sequence_parallel_func(stage="pad", tokenizer=tokenizer)
    padded_dataset = dataset.map(pad_sequence_func, batched=True, batch_size=args.preprocessing_batch_size, **kwargs)
    kwargs = dict(
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=args.local_process_index != 0,
        desc="Running sequence parallel split on dataset",
    )
    sp_dataset_func = _get_sequence_parallel_func(stage="split", tokenizer=tokenizer)
    sp_dataset = padded_dataset.map(sp_dataset_func, batched=True, batch_size=args.preprocessing_batch_size, **kwargs)
    return sp_dataset


def _get_sequence_parallel_func(
    stage: Literal["pad", "split"],
    tokenizer: AutoTokenizer,
):
    args = get_args()

    if stage == "pad":
        preprocess_func = partial(pad_sequence, args=args, tokenizer=tokenizer)
    elif stage == "split":
        preprocess_func = partial(sp_split, args=args)

    return preprocess_func


def get_dataset_module(
    tokenizer: AutoTokenizer,
    template: Template,
    processor: Optional[ProcessorMixin] = None,
):
    args = get_args()

    with args.hf_seq2seq_args.main_process_first(desc="load dataset"):
        train_dataset = _get_merged_datasets(args.dataset)
        eval_dataset = _get_merged_datasets(args.eval_dataset)

    logger.debug("Finish load data, train_dataset = {}, eval_dataset = {}".format(train_dataset, eval_dataset))

    with args.hf_seq2seq_args.main_process_first(desc="preprocess dataset"):
        train_dataset = _get_preprocessed_dataset(train_dataset, template, tokenizer, processor)
        eval_dataset = _get_preprocessed_dataset(eval_dataset, template, tokenizer, processor)

    if args.sequence_parallel_size > 1:
        with args.hf_seq2seq_args.main_process_first(desc="preprocess dataset"):
            train_dataset = _get_sequence_parallel_dataset(train_dataset, tokenizer)
            eval_dataset = _get_sequence_parallel_dataset(eval_dataset, tokenizer)

    dataset_module = {"train_dataset": train_dataset, "eval_dataset": eval_dataset}

    return dataset_module
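To make the two-pass pipeline above concrete, here is a toy sketch with the datasets library: the first batched map pads every row, the second splits each padded row into sequence_parallel_size shards, multiplying the row count. The pad/split functions below are simplified stand-ins for pad_sequence and sp_split, and the sizes are made up.

from datasets import Dataset

sp_size, max_length = 2, 8
ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})

def pad(batch):
    # Pad every row to max_length (stand-in for pad_sequence).
    batch["input_ids"] = [row + [0] * (max_length - len(row)) for row in batch["input_ids"]]
    return batch

def split(batch):
    # Cut each padded row into sp_size contiguous shards (stand-in for sp_split).
    step = max_length // sp_size
    batch["input_ids"] = [row[s : s + step] for row in batch["input_ids"] for s in range(0, max_length, step)]
    return batch

ds = ds.map(pad, batched=True).map(split, batched=True)
print(len(ds), ds[0]["input_ids"])  # 4 [1, 2, 3, 0] -- each original row became sp_size rows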
new file: src/openmind/flow/datasets/preprocess/sequence_parallel.py (63 lines)
@@ -0,0 +1,63 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

from openmind.utils.constants import IGNORE_INDEX


def preprocess_sp_dataset(seq_ids, world_size):
    step = len(seq_ids) // world_size
    local_values = [seq_ids[s : s + step] for s in range(0, len(seq_ids), step)]
    return local_values


def pad_sequence(examples, args, tokenizer):
    max_length = args.max_length
    input_pad_token_id = tokenizer.pad_token_id
    label_pad_token_id = IGNORE_INDEX if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    for k, v in examples.items():
        if k.endswith("input_ids"):
            pad_token_id = input_pad_token_id
        elif k.endswith("labels"):
            pad_token_id = label_pad_token_id
            # shift labels here
            for i in range(len(v)):
                v[i] = v[i][1:]
        elif k.endswith("attention_mask"):
            pad_token_id = 0
        elif k.endswith("position_ids"):
            pad_token_id = max_length - 1  # pad the max position id
        elif k == "images" or k == "videos" or k == "audios":
            pad_token_id = -1
            continue
        else:
            raise NotImplementedError(f"Unexpected dataset key: {k}")
        for i in range(len(v)):
            v[i].extend([pad_token_id] * (max_length - len(v[i])))
        examples[k] = v

    return examples


def sp_split(examples, args):
    for k, v in examples.items():
        chunks = list()
        for row in v:
            if k.endswith("attention_mask"):
                chunks.extend([row] * args.sequence_parallel_size)
            elif row is None:
                chunks.extend([None] * args.sequence_parallel_size)
            else:
                chunks.extend(preprocess_sp_dataset(row, args.sequence_parallel_size))
        examples[k] = chunks
    return examples
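A standalone illustration of the contiguous split that preprocess_sp_dataset performs on each padded row (toy numbers, no openMind imports needed):

seq_ids = list(range(16))  # one padded sequence of length 16
world_size = 4             # sequence_parallel_size

step = len(seq_ids) // world_size
shards = [seq_ids[s : s + step] for s in range(0, len(seq_ids), step)]
print(shards)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] -- one slice per rank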
@@ -70,7 +70,15 @@ def preprocess_supervised_dataset(
    Returns:
        Processed data.
    """
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "images": [], "videos": [], "audios": []}
    model_inputs = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
        "position_ids": [],
        "images": [],
        "videos": [],
        "audios": [],
    }

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@@ -92,6 +100,7 @@ def preprocess_supervised_dataset(
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["position_ids"].append(list(range(len(input_ids))))
        model_inputs["images"].append(examples["images"][i])
        model_inputs["videos"].append(examples["videos"][i])
        model_inputs["audios"].append(examples["audios"][i])
@@ -41,6 +41,7 @@ from openmind.integrations.transformers.npu_fused_ops.sdk import SUPPORTED_FUSED
from openmind.flow.arguments import get_args
from openmind.flow.model.model_registry import SUPPORTED_MODELS
from openmind.flow.model.adapter import apply_adapter
from openmind.flow.model.sequence_parallel.seq_utils import apply_sequence_parallel
from openmind.integrations.transformers.bitsandbytes import patch_bnb
from openmind.utils.loader_utils import get_platform_loader

@@ -335,6 +336,7 @@ def get_model():
        **init_kwargs,
    )

    apply_sequence_parallel(args, config.num_attention_heads)
    model = apply_adapter(model, args.do_train)

    if args.do_train:
new file: src/openmind/flow/model/sequence_parallel/seq_comm.py (113 lines)
@@ -0,0 +1,113 @@
# Copyright (c) Microsoft Corporation and Jiarui Fang
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
# modified from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/comm/all_to_all.py

from typing import Any, Tuple

import torch
from torch import Tensor
import torch.distributed as dist


def all_to_all_4D(hidden_states: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.tensor:
    """
    all-to-all for QKV

    Args:
        hidden_states (torch.tensor): a tensor sharded along the scatter dim
        scatter_idx (int): default 2
        gather_idx (int): default 1
        group: torch process group

    Returns:
        torch.tensor: resharded tensor, (bs, seqlen, hc/P, hs) when scatter_idx == 2 and gather_idx == 1,
            or (bs, seqlen/P, hc, hs) when scatter_idx == 1 and gather_idx == 2
    """
    if hidden_states.dim() != 4:
        raise ValueError(f"hidden_states must be 4D tensor, got {hidden_states.dim()} and shape {hidden_states.shape}")

    seq_world_size = dist.get_world_size(group)
    if scatter_idx == 2 and gather_idx == 1:
        # hidden_states: a tensor sharded along the sequence dim (bs, seqlen/P, hc, hs); output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = hidden_states.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        hidden_states_t = (
            hidden_states.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous()
        )

        output = torch.empty_like(hidden_states_t)
        # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, hidden_states_t, group=group)
        else:
            output = hidden_states_t
        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
        output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # hidden_states: a tensor sharded along the head dim (bs, seqlen, hc/P, hs); output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = hidden_states.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size
        seq_world_size = dist.get_world_size(group)

        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
        hidden_states_t = (
            hidden_states.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs)
            .transpose(0, 3)
            .transpose(0, 1)
            .contiguous()
            .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs)
        )

        output = torch.empty_like(hidden_states_t)
        # (P, hc/P, seqlen/P, bs, hs) scatter head -all2all-> (P, hc/P, seqlen/P, bs, hs) scatter seqlen
        if seq_world_size > 1:
            dist.all_to_all_single(output, hidden_states_t, group=group)
        else:
            output = hidden_states_t

        # if scattering the head-dim, transpose the sequence back to the original dimension
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs)
        output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")

class SeqAllToAll4D(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        hidden_states: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:

        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        return all_to_all_4D(hidden_states, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # One gradient per forward input (group, hidden_states, scatter_idx, gather_idx).
        return (
            None,
            SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
            None,
            None,
        )
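To visualize the relayout that all_to_all_4D performs for scatter_idx=2, gather_idx=1, the single-process sketch below only checks shapes by chunking a full tensor the way the P ranks would hold it before and after the collective; it does not run the communication itself, and the sizes are arbitrary.

import torch

bs, seqlen, hc, hs, P = 2, 8, 4, 3, 2
full = torch.arange(bs * seqlen * hc * hs).reshape(bs, seqlen, hc, hs)

rank_inputs = list(full.chunk(P, dim=1))   # before: each rank holds a sequence shard, all heads
rank_outputs = list(full.chunk(P, dim=2))  # after: each rank holds all positions, a head shard

for r in range(P):
    assert rank_inputs[r].shape == (bs, seqlen // P, hc, hs)
    assert rank_outputs[r].shape == (bs, seqlen, hc // P, hs)
print(rank_outputs[0].shape)  # torch.Size([2, 8, 2, 3])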
new file: src/openmind/flow/model/sequence_parallel/seq_utils.py (103 lines)
@@ -0,0 +1,103 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# openMind is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

from functools import partial

import torch
import torch.distributed as dist

from openmind.flow.model.sequence_parallel.ulysses import UlyssesAttention

_SEQUENCE_PARALLEL_GROUP = None


class DistributedTrainingModule:

    @staticmethod
    def initialize_sequence_parallel(
        world_size: int,
        rank: int,
        sequence_parallel_size: int = 1,
    ):
        global _SEQUENCE_PARALLEL_GROUP

        if world_size % sequence_parallel_size:
            raise ValueError(
                f"World size ({world_size}) must be divisible by sequence parallel size ({sequence_parallel_size})."
            )

        num_sequence_parallel_groups = world_size // sequence_parallel_size

        for i in range(num_sequence_parallel_groups):
            ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
            group = dist.new_group(ranks)
            if rank in ranks:
                _SEQUENCE_PARALLEL_GROUP = group

    @staticmethod
    def get_sequence_parallel_world_size():
        return get_sequence_parallel_world_size()

    @staticmethod
    def get_sequence_parallel_rank():
        return get_sequence_parallel_rank()

    @staticmethod
    def get_sequence_parallel_group():
        return get_sequence_parallel_group()


def get_sequence_parallel_world_size():
    return dist.get_world_size(group=get_sequence_parallel_group())


def get_sequence_parallel_rank():
    return dist.get_rank(group=get_sequence_parallel_group())


def get_sequence_parallel_group():
    if _SEQUENCE_PARALLEL_GROUP is None:
        raise ValueError(
            "The sequence parallel group is not initialized. Please call initialize_sequence_parallel first."
        )
    return _SEQUENCE_PARALLEL_GROUP


def new_attn_forward(
    query_states,
    key_states,
    value_states,
    group,
    attn_fn,
    dropout_p=0.0,
    scale=None,
    **kwargs,
):
    dist_attn = UlyssesAttention(sequence_process_group=group, attn_fn=attn_fn)
    attn_output = dist_attn(query_states, key_states, value_states, dropout_p, scale)

    return attn_output


def apply_sequence_parallel(args, num_head):
    if num_head % args.sequence_parallel_size:
        raise ValueError(
            "num_attention_head must be divisible by sequence_parallel_size for sequence parallel training. "
            f"{num_head} is not divisible by {args.sequence_parallel_size}."
        )

    if args.sequence_parallel_size > 1:
        group_this = DistributedTrainingModule.get_sequence_parallel_group()
        original_attn = torch.nn.functional.scaled_dot_product_attention
        new_attention_forward = partial(new_attn_forward, group=group_this, attn_fn=original_attn)
        torch.nn.functional.scaled_dot_product_attention = new_attention_forward
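The group construction above assigns consecutive ranks to the same sequence parallel group; a quick arithmetic sketch (no torch.distributed needed, sizes made up):

world_size, sequence_parallel_size = 8, 4

groups = [
    list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
    for i in range(world_size // sequence_parallel_size)
]
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- e.g. rank 5 joins the second group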
new file: src/openmind/flow/model/sequence_parallel/ulysses.py (73 lines)
@@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation and Jiarui Fang
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team & Jiarui Fang
# modified from https://github.com/feifeibear/long-context-attention/blob/main/yunchang/ulysses/attn_layer.py

from typing import Optional

import torch
import torch.distributed as dist

from openmind.flow.model.sequence_parallel.seq_comm import SeqAllToAll4D


class UlyssesAttention(torch.nn.Module):
    def __init__(
        self,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        attn_fn: Optional[callable] = None,
    ) -> None:

        super(UlyssesAttention, self).__init__()
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.attn_fn = attn_fn

    def forward(
        self,
        query,
        key,
        value,
        dropout,
        scaling,
        *args,
    ):
        # (bs, head_cnt, seq_len/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, head_cnt/N, seq_len, head_size)
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()

        if scaling is None:
            scaling = q.shape[-1] ** -0.5

        context_layer = self.attn_fn(
            q,
            k,
            v,
            dropout_p=dropout,
            scale=scaling,
            is_causal=True,
        )

        if isinstance(context_layer, tuple):
            context_layer = context_layer[0]
        context_layer = context_layer.transpose(1, 2)

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        output = output.transpose(1, 2).contiguous()
        return output
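A single-process smoke-test sketch of UlyssesAttention: with world_size 1 the all-to-alls degenerate to reshapes, so the module can be exercised without multiple NPUs. It assumes the openmind package from this PR is importable and torch >= 2.1 (for the scale argument of scaled_dot_product_attention); tensor sizes are arbitrary.

import torch
import torch.distributed as dist
import torch.nn.functional as F

from openmind.flow.model.sequence_parallel.ulysses import UlyssesAttention

# One-process gloo group so dist.get_world_size() == 1 inside the all-to-all.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

bs, heads, seq_shard, head_dim = 1, 4, 8, 16
q = torch.randn(bs, heads, seq_shard, head_dim)
k = torch.randn(bs, heads, seq_shard, head_dim)
v = torch.randn(bs, heads, seq_shard, head_dim)

attn = UlyssesAttention(sequence_process_group=None, attn_fn=F.scaled_dot_product_attention)
out = attn(q, k, v, 0.0, None)  # dropout=0.0, scaling=None -> 1/sqrt(head_dim)
print(out.shape)  # torch.Size([1, 4, 8, 16]), same layout as the wrapped attention

dist.destroy_process_group()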
new file: src/openmind/flow/train/sft/seq_utils.py (119 lines)
@@ -0,0 +1,119 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import torch
import torch.distributed as dist
from torch.nn import CrossEntropyLoss
from torch.utils.data import Sampler
from transformers import Seq2SeqTrainer
from transformers.trainer import _is_peft_model
from typing_extensions import override

from openmind.flow.arguments import get_args
from openmind.flow.model.sequence_parallel.seq_utils import DistributedTrainingModule
from openmind.utils.version import is_transformers_version_equal_to_4_46


class SequenceParallelSampler(Sampler[int]):
    def __init__(self, data_source, per_device_bs, sequence_parallel_size, world_size):
        self.sequence_parallel_size = sequence_parallel_size
        self.per_device_bs = per_device_bs
        self.num_data = len(data_source) // self.sequence_parallel_size
        self.batch_size = self.per_device_bs * world_size // self.sequence_parallel_size
        self.num_pad_data = (
            ((self.num_data // self.batch_size) + 1) * self.batch_size
            if self.num_data // self.batch_size
            else self.num_data
        )

    def __iter__(self):
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator = torch.Generator()
        generator.manual_seed(seed)
        sample_indices = torch.randperm(self.num_data, generator=generator).tolist()
        sample_indices = sample_indices + sample_indices[: self.num_pad_data - self.num_data]

        shuffle_indices = []
        group = []
        for idx, sample_idx in enumerate(sample_indices):
            group.append([sample_idx * self.sequence_parallel_size + i for i in range(self.sequence_parallel_size)])
            if idx % self.per_device_bs == self.per_device_bs - 1:
                group_for_sp = list(itertools.chain(*list(zip(*group))))
                shuffle_indices.append(group_for_sp)
                group = []
        shuffle_indices = list(itertools.chain(*shuffle_indices))
        yield from shuffle_indices

    def __len__(self):
        return self.num_pad_data * self.sequence_parallel_size


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    @override
    def _get_train_sampler(self):
        args = get_args()

        if args.sequence_parallel_size > 1:
            return SequenceParallelSampler(
                self.train_dataset,
                args.per_device_train_batch_size,
                args.sequence_parallel_size,
                dist.get_world_size(),
            )
        else:
            return super()._get_train_sampler()

    @override
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        args = get_args()

        if args.sequence_parallel_size == 1:
            loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
        else:
            _, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
            loss_function = CrossEntropyLoss(reduction="sum")
            logits, labels = outputs["logits"] if isinstance(outputs, dict) else outputs[1], inputs["labels"]

            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                vocab_size = unwrapped_model.base_model.model.config.vocab_size
            else:
                vocab_size = unwrapped_model.config.vocab_size

            logits = logits.view(-1, vocab_size)
            labels = labels.view(-1)

            labels = labels.to(logits.device)
            loss = loss_function(logits, labels)

            sp_group = DistributedTrainingModule.get_sequence_parallel_group()
            loss = dist.nn.all_reduce(loss, op=dist.ReduceOp.SUM, group=sp_group)
            label_num = (labels != loss_function.ignore_index).sum()
            label_num = dist.nn.all_reduce(label_num, op=dist.ReduceOp.SUM, group=sp_group)
            loss /= label_num

        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
            if return_outputs:
                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
            else:
                return loss / self.args.gradient_accumulation_steps

        return loss
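The loss handling above sums the cross-entropy and the valid-label count on each sequence-parallel rank and all-reduces both before dividing; a toy arithmetic sketch of that reduction (made-up numbers, no distributed setup):

per_rank_loss_sum = [12.0, 9.0]  # reduction="sum" cross-entropy on two sequence-parallel ranks
per_rank_label_num = [30, 18]    # labels != ignore_index on each rank

global_loss = sum(per_rank_loss_sum) / sum(per_rank_label_num)
print(global_loss)  # 0.4375 -- every rank gets the same mean, as if the full sequence were local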
@@ -13,14 +13,17 @@

from typing import List, Optional

from transformers import Seq2SeqTrainer, TrainerCallback
import torch.distributed as dist
from transformers import TrainerCallback

from openmind.utils import get_logger
from openmind.utils.constants import IGNORE_INDEX
from openmind.flow.model import get_model, get_tokenizer_and_processor
from openmind.flow.model.sequence_parallel.seq_utils import DistributedTrainingModule
from openmind.flow.datasets import get_template, get_dataset_module
from openmind.flow.arguments import get_args
from openmind.flow.datasets.collator import SFTDataCollatorWith4DAttentionMask
from openmind.flow.train.sft.seq_utils import CustomSeq2SeqTrainer

logger = get_logger(__name__)

@@ -28,12 +31,18 @@ logger = get_logger(__name__)
def run_sft(
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    args = get_args()

    if args.sequence_parallel_size > 1:
        DistributedTrainingModule.initialize_sequence_parallel(
            dist.get_world_size(), dist.get_rank(), args.sequence_parallel_size
        )

    tokenizer, processor = get_tokenizer_and_processor()
    model = get_model()
    template = get_template()
    dataset_module = get_dataset_module(tokenizer, template, processor)

    args = get_args()
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        tokenizer=tokenizer,
@@ -41,11 +50,14 @@ def run_sft(
        processor=processor,
        padding="max_length" if args.max_length else True,
        pad_to_multiple_of=8 if args.do_train else None,
        max_length=args.max_length,
        max_length=(
            args.max_length if args.sequence_parallel_size == 1 else args.max_length // args.sequence_parallel_size
        ),
        label_pad_token_id=IGNORE_INDEX if args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        require_position_ids=args.sequence_parallel_size > 1,
    )

    trainer = Seq2SeqTrainer(
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=args.hf_seq2seq_args,
        tokenizer=tokenizer,
@@ -108,3 +108,14 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
    if want_ver is not None:
        for op, want_ver in wanted.items():
            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)


def get_package_version(package_name: str):
    try:
        return version.parse(importlib.metadata.version(package_name))
    except Exception:
        return version.parse("0.0.0")


def is_transformers_version_equal_to_4_46():
    return version.parse("4.46.0") <= get_package_version("transformers") <= version.parse("4.46.1")