mirror of
https://gitee.com/ascend/MindSpeed-RL.git
synced 2025-10-20 16:23:45 +08:00
Compare commits
4 Commits
e8a66bdd82
...
1604cbc4a9
Author | SHA1 | Date | |
---|---|---|---|
1604cbc4a9 | |||
d76a536683 | |||
114cb64341 | |||
7e2a9d348a |
@ -235,6 +235,7 @@ MindSpeed RL是基于昇腾生态的强化学习加速框架,旨在为华为 [
|
||||
GRPO <br>
|
||||
DAPO <br>
|
||||
PPO <br>
|
||||
DPO <br>
|
||||
</td>
|
||||
</td>
|
||||
<td> Preview</td>
|
||||
|
@ -34,8 +34,6 @@ megatron_training:
|
||||
swap_optimizer: true
|
||||
moe_alltoall_overlap_comm: true
|
||||
reset_position_ids: true
|
||||
use_ascend_coc: true
|
||||
coc_fused_kernel: true
|
||||
|
||||
actor_config:
|
||||
model: deepseekv3_671b
|
||||
|
@ -76,6 +76,7 @@ megatron_training:
|
||||
```
|
||||
|
||||
## 性能数据
|
||||
| 模型 | 机器型号 | GBS | 集群 | 方案 | 序列 | 性能 |
|
||||
|---|---|---|---|---|---|---|
|
||||
| Qwen3-30B-A3B | Atlas A2 | 64 | 2x8 | 全参 | dynamic | 2.78 samples/s |
|
||||
| 模型 | 机器型号 | GBS | 集群 | 方案 | 序列 | 性能 |
|
||||
|---|----------|---|---|---|---|----------------|
|
||||
| Qwen3-30B-A3B | Atlas A2 | 64 | 2x8 | 全参 | dynamic | 2.78 samples/s |
|
||||
| Qwen3-30B-A3B | Atlas A3 | 64 | 2x8 | 全参 | dynamic | 7.19 samples/s |
|
@ -41,6 +41,12 @@
|
||||
```
|
||||
guarantee_order: true
|
||||
```
|
||||
- 1.3 DAPO场景
|
||||
|
||||
DAPO场景下使能确定性计算,应配置动态采样参数为false,才能保证每轮迭代的数据输入是一致的。
|
||||
```
|
||||
filter_groups_enable: true
|
||||
```
|
||||
|
||||
### 2. 使能确定性计算参数
|
||||
- 2.1 使能算子API确定性计算
|
||||
|
@ -53,3 +53,6 @@ wandb可视化训练指标效果示例:
|
||||

|
||||
|
||||

|
||||
|
||||
注:
|
||||
>1.qwen3-30b dpo暂不支持TensorBoard,你可以使用WandB取代!
|
||||
|
@ -169,7 +169,7 @@ cd ..
|
||||
|
||||
git clone https://gitee.com/ascend/MindSpeed-LLM.git -b 2.1.0
|
||||
cd MindSpeed-LLM
|
||||
git checkout df21738b7234240c704c3ba232daf069eeed57b4
|
||||
git checkout 180ef141
|
||||
cp -r mindspeed_llm ../MindSpeed-RL/
|
||||
cd ..
|
||||
|
||||
|
@ -69,7 +69,7 @@ class GenerateConfig(BaseConfig):
|
||||
# 模型的最大长度(以 token 为单位),默认为 2048
|
||||
self.max_model_len = 2048
|
||||
|
||||
self.max_num_batched_tokens = 8192
|
||||
self.max_num_batched_tokens = 2048
|
||||
|
||||
# 模型权重的数据类型,默认为 bfloat16
|
||||
self.dtype = "bfloat16"
|
||||
|
@ -48,7 +48,7 @@ class VLLMInferEngine(BaseInferEngine):
|
||||
num_scheduler_steps: int = 1,
|
||||
max_num_seqs: int = 1,
|
||||
max_model_len: int = 2048,
|
||||
max_num_batched_tokens: int = 8192,
|
||||
max_num_batched_tokens: int = 2048,
|
||||
dtype: str = "bfloat16",
|
||||
gpu_memory_utilization: float = 0.5,
|
||||
trust_remote_code: bool = True,
|
||||
|
3
rl-plugin/transformer_npu/__init__.py
Normal file
3
rl-plugin/transformer_npu/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from transformers_npu.qwen2_patch import apply_qwen2_patch
|
||||
|
||||
apply_qwen2_patch()
|
30
rl-plugin/transformer_npu/qwen2_patch.py
Normal file
30
rl-plugin/transformer_npu/qwen2_patch.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch_npu import npu_rotary_mul as apply_rotary_emb
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
|
||||
|
||||
|
||||
def rms_norm_forward(self, x):
|
||||
return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
|
||||
def silu_forward(self, hidden_state):
|
||||
return self.down_proj(
|
||||
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
|
||||
)
|
||||
|
||||
|
||||
def fused_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueese_dim=1):
|
||||
cos = cos.unsqueese(unsqueese_dim)
|
||||
sin = sin.unsqueese(unsqueese_dim)
|
||||
q_embed = torch_npu.npu_rotary_mul(q.contiguous(), cos, sin).to(q.dtype)
|
||||
k_embed = torch_npu.npu_rotary_mul(k.contiguous(), cos, sin).to(k.dtype)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def apply_qwen2_patch():
|
||||
Qwen2MLP.forward = silu_forward
|
||||
Qwen2RMSNorm.forward = rms_norm_forward
|
||||
modeling_qwen2.fused_apply_rotary_pos_emb = fused_apply_rotary_pos_emb
|
@ -333,6 +333,7 @@ class TestActor():
|
||||
sampling_config=sampling_config,
|
||||
max_num_seqs=16,
|
||||
max_model_len=4096,
|
||||
max_num_batched_tokens=8192,
|
||||
dtype="bfloat16",
|
||||
gpu_memory_utilization=0.6,
|
||||
trust_remote_code=True,
|
||||
|
Reference in New Issue
Block a user