Add NPU fused RMSNorm support

This commit is contained in:
frozenleaves
2025-09-24 16:23:40 +08:00
parent a746f4f55d
commit 8350b52da6
2 changed files with 8 additions and 6 deletions

View File

@ -11,9 +11,6 @@ class Qwen3RMSNorm(nn.Module):
variance_epsilon: float
def forward(self, hidden_states):
    """Apply RMS normalization to ``hidden_states`` with the fused NPU kernel.

    Args:
        hidden_states: Input tensor. If it does not already live on an
            ``npu`` device it is moved there before the kernel is called.

    Returns:
        The first element of the result of ``torch_npu.npu_rms_norm``
        called with ``self.weight`` and ``epsilon=self.variance_epsilon``.
        NOTE(review): presumably this is the normalized tensor and the
        remaining elements are auxiliary outputs -- confirm against the
        torch_npu API reference.
    """
    # Make sure the input is on the NPU device before invoking the NPU-only op.
    if hidden_states.device.type != "npu":
        hidden_states = hidden_states.to("npu")
    # No debug printing here: this runs on every forward pass, so an
    # unconditional print would flood stdout in training/inference loops.
    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

View File

@ -1,4 +1,5 @@
# tests/test_rmsnorm_npu.py
import inspect
import unittest
import pathlib
import kernels
@ -27,14 +28,18 @@ class TestNpuRMSNorm(unittest.TestCase):
kernel_layer_mapping = {
"Qwen3RMSNorm": {
"npu": kernels.LocalLayerRepository(
repo_path=pathlib.Path("../rmsnorm_npu"),
package_name="rmsnorm-npu",
repo_path=pathlib.Path("/home/openmind/rmsnorm-npu"),
package_name="rmsnorm_npu",
layer_name="Qwen3RMSNorm",
)
}
}
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm, Qwen3MLP
target_signature = inspect.signature(Qwen3RMSNorm.forward)
print(f"目标层签名: {target_signature}")
print(f"Qwen3MLP签名{inspect.signature(Qwen3MLP.forward)}")
with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
model = kernelize(self.model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, use_fallback=False)
model = kernelize(self.model, mode=Mode.TRAINING , use_fallback=False)
print(model.model.layers[0])
x = torch.randn(1, 3, 2048, 2048).to('npu')
model.model.layers[0].input_layernorm.forward(x)