!556 [pytorch][model] patch npu fused operator to verl

Merge pull request !556 from guihaowen/master
guihaowen
2025-08-27 07:30:01 +00:00
committed by i-robot
parent d76a536683
commit 1604cbc4a9
2 changed files with 33 additions and 0 deletions


@@ -0,0 +1,3 @@
from transformers_npu.qwen2_patch import apply_qwen2_patch
apply_qwen2_patch()


@@ -0,0 +1,30 @@
import torch
import torch_npu
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP


def rms_norm_forward(self, x):
    # Replace the eager RMSNorm forward with the fused NPU kernel.
    return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]


def silu_forward(self, hidden_state):
    # Fuse the gate/up projections and SiLU activation into a single npu_swiglu call.
    return self.down_proj(
        torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
    )


def fused_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Apply rotary position embeddings through the fused NPU rotary kernel.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q.contiguous(), cos, sin).to(q.dtype)
    k_embed = torch_npu.npu_rotary_mul(k.contiguous(), cos, sin).to(k.dtype)
    return q_embed, k_embed


def apply_qwen2_patch():
    # Monkey-patch the HuggingFace Qwen2 modules to use the fused NPU operators.
    Qwen2MLP.forward = silu_forward
    Qwen2RMSNorm.forward = rms_norm_forward
    modeling_qwen2.fused_apply_rotary_pos_emb = fused_apply_rotary_pos_emb
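For reference, a minimal usage sketch follows. It is not part of the commit: the checkpoint name "Qwen/Qwen2-7B-Instruct" is a placeholder, and the sketch assumes an Ascend NPU host with torch_npu installed. The patch only needs to be applied once before the model is run so that Qwen2RMSNorm and Qwen2MLP dispatch to the fused NPU kernels.

# Minimal sketch, assuming torch_npu and an Ascend NPU device are available.
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformers_npu.qwen2_patch import apply_qwen2_patch

# Bind the fused NPU forwards onto the Qwen2 modules.
apply_qwen2_patch()

model_name = "Qwen/Qwen2-7B-Instruct"  # placeholder checkpoint, not taken from the commit
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("npu")

inputs = tokenizer("Hello", return_tensors="pt").to("npu")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))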