Add NPU fused RMSNorm support
This commit is contained in:
@ -11,9 +11,6 @@ class Qwen3RMSNorm(nn.Module):
|
||||
variance_epsilon: float
|
||||
|
||||
def forward(self, hidden_states):
    """Apply RMS normalization through the fused ``torch_npu`` kernel.

    Args:
        hidden_states: Input tensor to normalize. If it is not already on
            an ``npu`` device it is moved there before the kernel call.

    Returns:
        The first element of the result of ``torch_npu.npu_rms_norm``
        called with ``self.weight`` and ``epsilon=self.variance_epsilon``.
        NOTE(review): element 0 is presumably the normalized tensor and the
        remaining element(s) auxiliary statistics -- confirm against the
        torch_npu API reference.
    """
    # Make sure the input lives on the NPU device; the fused kernel is
    # invoked only after this check.
    if hidden_states.device.type != "npu":
        hidden_states = hidden_states.to("npu")

    # No per-call debug output here: this method runs once per norm layer
    # per forward pass, so printing would flood stdout.
    return torch_npu.npu_rms_norm(
        hidden_states, self.weight, epsilon=self.variance_epsilon
    )[0]
|
||||
|
@ -1,4 +1,5 @@
|
||||
# tests/test_rmsnorm_npu.py
|
||||
import inspect
|
||||
import unittest
|
||||
import pathlib
|
||||
import kernels
|
||||
@ -27,14 +28,18 @@ class TestNpuRMSNorm(unittest.TestCase):
|
||||
kernel_layer_mapping = {
|
||||
"Qwen3RMSNorm": {
|
||||
"npu": kernels.LocalLayerRepository(
|
||||
repo_path=pathlib.Path("../rmsnorm_npu"),
|
||||
package_name="rmsnorm-npu",
|
||||
repo_path=pathlib.Path("/home/openmind/rmsnorm-npu"),
|
||||
package_name="rmsnorm_npu",
|
||||
layer_name="Qwen3RMSNorm",
|
||||
)
|
||||
}
|
||||
}
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm, Qwen3MLP
|
||||
target_signature = inspect.signature(Qwen3RMSNorm.forward)
|
||||
print(f"目标层签名: {target_signature}")
|
||||
print(f"Qwen3MLP签名:{inspect.signature(Qwen3MLP.forward)}")
|
||||
with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
|
||||
model = kernelize(self.model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, use_fallback=False)
|
||||
model = kernelize(self.model, mode=Mode.TRAINING , use_fallback=False)
|
||||
print(model.model.layers[0])
|
||||
x = torch.randn(1, 3, 2048, 2048).to('npu')
|
||||
model.model.layers[0].input_layernorm.forward(x)
|
||||
|
Reference in New Issue
Block a user