add npu fused rmsnorm support

frozenleaves
2025-09-23 19:25:02 +08:00
parent 3aba737a3f
commit 4f3b4f6603
13 changed files with 266 additions and 0 deletions

3
.idea/.gitignore generated vendored Normal file

@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

6
.idea/inspectionProfiles/profiles_settings.xml generated Normal file

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

6
.idea/misc.xml generated Normal file

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="pt26" />
</component>
</project>

8
.idea/modules.xml generated Normal file

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/rmsnorm-npu.iml" filepath="$PROJECT_DIR$/.idea/rmsnorm-npu.iml" />
</modules>
</component>
</project>

8
.idea/rmsnorm-npu.iml generated Normal file

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="pt26" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

6
.idea/vcs.xml generated Normal file

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

12
rmsnorm_npu/__init__.py Normal file

@@ -0,0 +1,12 @@
# rmsnorm_npu/__init__.py
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from kernels import replace_kernel_forward_from_hub
from .layers import NpuRMSNorm
# Mark the RMSNorm implementations in transformers as kernel-extensible
replace_kernel_forward_from_hub(Qwen3RMSNorm, "RMSNorm")
replace_kernel_forward_from_hub(MistralRMSNorm, "RMSNorm")  # other RMSNorm variants can be registered the same way
__all__ = ['NpuRMSNorm']
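Other RMSNorm variants can be marked the same way. A minimal sketch, assuming you also want to cover Llama models (LlamaRMSNorm is used purely as an illustration and is not registered by this commit):

# Illustration only: marking an additional RMSNorm variant as kernel-extensible.
# LlamaRMSNorm is not registered by this package.
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from kernels import replace_kernel_forward_from_hub

replace_kernel_forward_from_hub(LlamaRMSNorm, "RMSNorm")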

47
rmsnorm_npu/integration.py Normal file

@@ -0,0 +1,47 @@
# rmsnorm_npu/integration.py
from transformers import PreTrainedModel
from kernels import kernelize, Mode
import logging
import torch_npu
logger = logging.getLogger(__name__)
def apply_npu_rmsnorm(model, mode=Mode.INFERENCE | Mode.TRAINING, use_fallback=True):
    """
    Replace the RMSNorm layers in a model with the NPU-optimized version.
    Args:
        model: a transformers model
        mode: execution mode (INFERENCE/TRAINING)
        use_fallback: whether to fall back to the original implementation when the kernel is unsupported
    Returns:
        kernelized_model: the model with the optimization applied
    """
    # Check NPU availability
    if not hasattr(torch_npu, 'npu_is_available') or not torch_npu.npu_is_available():
        logger.warning("NPU is not available. Falling back to standard RMSNorm.")
        return model
    # Target device type
    device = "npu"
    # Apply kernelize
    logger.info(f"Applying NPU-optimized RMSNorm to model {model.__class__.__name__}")
    kernelized_model = kernelize(
        model,
        mode=mode,
        device=device,
        use_fallback=use_fallback
    )
    # Best-effort verification: kernelize() swaps forward methods in place, so the
    # module class names are preserved; count the RMSNorm layers eligible for the kernel.
    rmsnorm_layers = [
        name for name, module in kernelized_model.named_modules()
        if module.__class__.__name__.endswith("RMSNorm")
    ]
    if rmsnorm_layers:
        logger.info(f"Found {len(rmsnorm_layers)} RMSNorm layers dispatched to the NPU kernel")
    else:
        logger.warning("No RMSNorm layers were found. Check if RMSNorm was properly marked as extensible.")
    return kernelized_model
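For reference, a minimal usage sketch of apply_npu_rmsnorm; the checkpoint name, dtype, and sequence length are illustrative, and an Ascend NPU plus the registered kernel mapping are assumed:

# Illustrative usage (model id and shapes are placeholders, not part of this commit).
import torch
import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch
from transformers import AutoModelForCausalLM
from kernels import Mode
from rmsnorm_npu.integration import apply_npu_rmsnorm

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.float16).to("npu")
model = apply_npu_rmsnorm(model, mode=Mode.INFERENCE)

input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="npu")
with torch.no_grad():
    logits = model(input_ids).logits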

39
rmsnorm_npu/kernel_init.py Normal file

@@ -0,0 +1,39 @@
# rmsnorm_npu/kernel_init.py
from kernels import register_kernel_mapping, LayerRepository, Device
import logging
def init_rmsnorm_npu():
    """Register the NPU RMSNorm kernel mapping."""
    logger = logging.getLogger(__name__)
    # Build the kernel mapping; "RMSNorm" matches the name used to mark the
    # transformers layers in rmsnorm_npu/__init__.py
    kernel_layer_mapping = {
        "RMSNorm": {
            Device(type="npu"): LayerRepository(
                repo_id="your-username/rmsnorm-npu",  # replace with your HF Hub repo id
                layer_name="NpuRMSNorm",
                version=">=0.0.1,<0.1.0",
            )
        }
    }
    # Register the mapping
    register_kernel_mapping(kernel_layer_mapping)
    logger.info("Registered NPU RMSNorm kernel mapping")
    return kernel_layer_mapping
# Optional: auto-initialize on import
if __name__ != "__main__":
    init_rmsnorm_npu()
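If the installed kernels release provides the use_kernel_mapping context manager (an assumption, check your version), the mapping can also be scoped to a single kernelize call instead of being registered globally; `model` below stands for an already loaded transformers model:

# Assumption: `use_kernel_mapping` exists in your kernels release; otherwise rely on
# the global register_kernel_mapping() call above. `model` is assumed to be loaded.
from kernels import use_kernel_mapping, kernelize, Mode, LayerRepository, Device

mapping = {
    "RMSNorm": {
        Device(type="npu"): LayerRepository(
            repo_id="your-username/rmsnorm-npu",  # placeholder repo id, as above
            layer_name="NpuRMSNorm",
        )
    }
}
with use_kernel_mapping(mapping):
    model = kernelize(model, mode=Mode.INFERENCE, device="npu")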

28
rmsnorm_npu/layers.py Normal file

@@ -0,0 +1,28 @@
# rmsnorm_npu/layers.py
import torch
import torch.nn as nn
import torch_npu
from kernels import use_kernel_forward_from_hub
@use_kernel_forward_from_hub("RMSNorm")
class NpuRMSNorm(nn.Module):
    """
    RMSNorm operator optimized for NPU. When an NPU is available, it replaces the standard RMSNorm implementation.
    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
    def forward(self, hidden_states):
        # Ensure the input is on the NPU device
        if hidden_states.device.type != "npu":
            hidden_states = hidden_states.to("npu")
        # Run the fused NPU RMSNorm kernel; npu_rms_norm returns (output, rstd)
        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
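A quick parity sketch against a plain PyTorch RMSNorm can help validate the fused kernel; an Ascend NPU is assumed, and the shapes and tolerances are illustrative:

# Parity check sketch (requires an NPU; shapes and tolerances are illustrative).
import torch
import torch_npu  # noqa: F401
from rmsnorm_npu.layers import NpuRMSNorm

def reference_rmsnorm(x, weight, eps=1e-6):
    # Plain-PyTorch RMSNorm used only as a reference
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

hidden_size = 128
layer = NpuRMSNorm(hidden_size).to("npu")
x = torch.randn(2, 16, hidden_size, device="npu", dtype=torch.float32)

fused = layer(x)
ref = reference_rmsnorm(x, layer.weight, eps=layer.variance_epsilon)
print(torch.allclose(fused, ref, atol=1e-3, rtol=1e-3))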

25
setup.py Normal file

@@ -0,0 +1,25 @@
# setup.py
from setuptools import setup, find_packages
setup(
    name="rmsnorm-npu",
    version="0.0.1",
    packages=find_packages(),
    install_requires=[
        "torch>=2.5.1",
        "torch_npu>=2.5.1",  # make sure the torch_npu version matches your torch build
        "transformers>=4.51.3",
        "kernels",  # the library that provides the hub kernel machinery
    ],
    description="NPU optimized RMSNorm kernel for transformers",
    author="Your Name",
    author_email="your.email@example.com",
    url="https://github.com/frozenleaves/rmsnorm-npu",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
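After an editable install (pip install -e . from the repository root), a quick import check confirms the package and the Ascend stack resolve; the snippet below is a sanity check, not part of the commit:

# Post-install sanity check (not part of this commit).
import torch
import torch_npu  # noqa: F401  # fails here if the Ascend toolchain is not set up
import rmsnorm_npu

print(rmsnorm_npu.__all__)        # expected: ['NpuRMSNorm']
print(torch.npu.is_available())   # availability check exposed by the torch_npu adapter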

0
tests/__init__.py Normal file

78
tests/test_rmsnorm_npu.py Normal file

@@ -0,0 +1,78 @@
# tests/test_rmsnorm_npu.py
import unittest
import logging
import torch
import torch_npu
from transformers import AutoModel
from rmsnorm_npu.integration import apply_npu_rmsnorm
from kernels import kernelize, Mode
class TestNpuRMSNorm(unittest.TestCase):
    def setUp(self):
        self.model_id = "meta-llama/Llama-2-7b-hf"  # or another available small model
        self.model = AutoModel.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        )
    def test_npu_rmsnorm_integration(self):
        """Test whether the NPU RMSNorm is integrated successfully."""
        if not torch_npu.npu_is_available():
            self.skipTest("NPU is not available")
        # Set up log capture
        log_output = []
        log_handler = logging.StreamHandler()
        log_handler.setFormatter(logging.Formatter('%(message)s'))
        log_handler.emit = lambda record: log_output.append(record.getMessage())
        # Get the logger used by the kernels module
        logger = logging.getLogger("kernels")
        logger.addHandler(log_handler)
        original_level = logger.level
        logger.setLevel(logging.INFO)
        try:
            # Apply the NPU RMSNorm
            kernelized_model = apply_npu_rmsnorm(self.model, mode=Mode.TRAINING)
            # Check the captured logs for a message indicating RMSNorm was replaced
            rmsnorm_replaced = False
            for message in log_output:
                if "RMSNorm" in message and ("replaced" in message or "using" in message):
                    rmsnorm_replaced = True
                    print(f"Found replacement message: {message}")
                    break
            self.assertTrue(rmsnorm_replaced, "RMSNorm layers were not replaced with NPU version")
            # Additional verification: check whether the model contains NpuRMSNorm instances
            from rmsnorm_npu.layers import NpuRMSNorm
            npu_rmsnorm_count = 0
            for name, module in kernelized_model.named_modules():
                if isinstance(module, NpuRMSNorm):
                    npu_rmsnorm_count += 1
                    print(f"Found NpuRMSNorm at: {name}")
            self.assertGreater(npu_rmsnorm_count, 0, "No NpuRMSNorm instances found in the model")
            # Test the forward pass (AutoModel returns hidden states, not logits)
            input_ids = torch.randint(0, 1000, (1, 100)).to("npu")
            with torch.no_grad():
                outputs = kernelized_model(input_ids)
            self.assertIsNotNone(outputs)
            self.assertTrue(torch.isfinite(outputs.last_hidden_state).all())
        finally:
            # Restore the logger state
            logger.removeHandler(log_handler)
            logger.setLevel(original_level)
if __name__ == "__main__":
    unittest.main()