add npu fused rmsnorm support
This commit is contained in:
3
.idea/.gitignore
generated
vendored
Normal file
3
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
6
.idea/misc.xml
generated
Normal file
6
.idea/misc.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="pt26" />
|
||||
</component>
|
||||
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/rmsnorm-npu.iml" filepath="$PROJECT_DIR$/.idea/rmsnorm-npu.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
8
.idea/rmsnorm-npu.iml
generated
Normal file
8
.idea/rmsnorm-npu.iml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="pt26" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
@ -0,0 +1,12 @@
|
||||
# rmsnorm_npu/__init__.py
"""Package entry point: marks transformers RMSNorm variants as kernel-extensible."""
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from transformers.models.mistral.modeling_mistral import MistralRMSNorm

from kernels import replace_kernel_forward_from_hub

from .layers import NpuRMSNorm

# Mark the transformers RMSNorm implementations as extensible so the kernels
# hub can swap in the NPU-optimized forward at kernelize() time.
replace_kernel_forward_from_hub(Qwen3RMSNorm, "RMSNorm")
replace_kernel_forward_from_hub(MistralRMSNorm, "RMSNorm")  # other RMSNorm variants, if any

__all__ = ['NpuRMSNorm']
|
47
rmsnorm_npu/integration.py
Normal file
47
rmsnorm_npu/integration.py
Normal file
@ -0,0 +1,47 @@
|
||||
# rmsnorm_npu/integration.py
from transformers import PreTrainedModel
from kernels import kernelize, Mode
import logging
import torch_npu

logger = logging.getLogger(__name__)


def apply_npu_rmsnorm(model, mode=Mode.INFERENCE | Mode.TRAINING, use_fallback=True):
    """Replace the RMSNorm layers in *model* with the NPU-optimized version.

    Args:
        model: A transformers model whose RMSNorm layers were marked as
            kernel-extensible (see ``rmsnorm_npu/__init__.py``).
        mode: kernels ``Mode`` flags selecting inference/training kernels.
        use_fallback: Whether to fall back to the original implementation
            when the hub kernel is not applicable.

    Returns:
        The kernelized model, or the untouched *model* when no NPU is
        available.
    """
    # Check NPU availability. Availability lives on the ``torch_npu.npu``
    # namespace; the attribute ``torch_npu.npu_is_available`` used previously
    # does not exist and would raise AttributeError here.
    try:
        npu_available = torch_npu.npu.is_available()
    except AttributeError:
        npu_available = False
    if not npu_available:
        logger.warning("NPU is not available. Falling back to standard RMSNorm.")
        return model

    # Apply kernelize on the NPU device.
    logger.info("Applying NPU-optimized RMSNorm to model %s", model.__class__.__name__)
    kernelized_model = kernelize(
        model,
        mode=mode,
        device="npu",
        use_fallback=use_fallback,
    )

    # Best-effort verification: count RMSNorm-like modules in the result.
    # The previous code called ``kernels.get_kernel(model)``, but get_kernel
    # expects a Hub repo id (a string), not a model, and would raise here.
    rmsnorm_layers = [
        name
        for name, module in kernelized_model.named_modules()
        if "RMSNorm" in type(module).__name__
    ]
    if rmsnorm_layers:
        logger.info("Found %d RMSNorm layers after kernelization", len(rmsnorm_layers))
    else:
        logger.warning(
            "No RMSNorm layers were replaced. "
            "Check if RMSNorm was properly marked as extensible."
        )

    return kernelized_model
|
39
rmsnorm_npu/kernel_init.py
Normal file
39
rmsnorm_npu/kernel_init.py
Normal file
@ -0,0 +1,39 @@
|
||||
# rmsnorm_npu/kernel_init.py
from kernels import register_kernel_mapping, LayerRepository, Device
import logging


def init_rmsnorm_npu():
    """Register the NPU RMSNorm kernel mapping with the kernels hub.

    Returns:
        The mapping dict that was registered.
    """
    logger = logging.getLogger(__name__)

    # NOTE(review): the original function built a first mapping (keyed on
    # "RMSNorm" with a Device(type="npu") key and a placeholder repo id
    # "your-username/rmsnorm-npu") and then immediately rebound
    # ``kernel_layer_mapping`` to the dict below, so the first dict was dead
    # code and only this mapping was ever registered. The dead store has been
    # removed; behavior is unchanged.
    kernel_layer_mapping = {
        "Qwen3RMSNorm": {
            "npu": LayerRepository(
                # NOTE(review): LayerRepository.repo_id is conventionally a
                # Hugging Face Hub id ("user/repo"), not a GitHub URL —
                # confirm this resolves before publishing.
                repo_id="https://github.com/frozenleaves/rmsnorm-npu",
                layer_name="RmsNorm",
                revision="layers",
            ),
        },
    }

    # Register the mapping with the kernels hub.
    register_kernel_mapping(kernel_layer_mapping)
    logger.info("Registered NPU RMSNorm kernel mapping")

    return kernel_layer_mapping


# Optional: register automatically on import.
# NOTE(review): import-time side effects make module import order significant;
# consider requiring callers to invoke init_rmsnorm_npu() explicitly.
if __name__ != "__main__":
    init_rmsnorm_npu()
|
@ -0,0 +1,28 @@
|
||||
# rmsnorm_npu/layers.py
import torch
import torch.nn as nn
import torch_npu
from kernels import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("RMSNorm")
class NpuRMSNorm(nn.Module):
    """RMSNorm operator optimized for NPU.

    When an NPU is available, this module replaces the standard RMSNorm
    implementation via the kernels hub.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale, initialized to ones like the reference
        # RMSNorm implementations in transformers.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # The fused kernel only accepts NPU tensors; migrate the input first
        # if it arrived on another device.
        if hidden_states.device.type != "npu":
            hidden_states = hidden_states.to("npu")

        # Run the fused NPU RMSNorm; the op returns a tuple whose first
        # element is the normalized output.
        result = torch_npu.npu_rms_norm(
            hidden_states, self.weight, epsilon=self.variance_epsilon
        )
        return result[0]

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
25
setup.py
25
setup.py
@ -0,0 +1,25 @@
|
||||
# setup.py
from setuptools import setup, find_packages

setup(
    name="rmsnorm-npu",
    version="0.0.1",
    packages=find_packages(),
    install_requires=[
        "torch>=2.5.1",
        "torch_npu>=2.5.1",  # torch_npu build must match the torch version
        "transformers>=4.51.3",
        "kernels",  # library providing the kernel-hub integration
    ],
    description="NPU optimized RMSNorm kernel for transformers",
    author="Your Name",
    author_email="your.email@example.com",
    url="https://github.com/frozenleaves/rmsnorm-npu",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
78
tests/test_rmsnorm_npu.py
Normal file
78
tests/test_rmsnorm_npu.py
Normal file
@ -0,0 +1,78 @@
|
||||
# tests/test_rmsnorm_npu.py
import unittest
import logging
import torch
import torch_npu
from transformers import AutoModel
from rmsnorm_npu.integration import apply_npu_rmsnorm
from kernels import kernelize, Mode


def _npu_available():
    """Best-effort NPU availability check.

    The previous ``torch_npu.npu_is_available()`` call is not a real
    torch_npu attribute; availability lives on ``torch_npu.npu``.
    """
    try:
        return torch_npu.npu.is_available()
    except AttributeError:
        return False


class _ListHandler(logging.Handler):
    """Logging handler that collects formatted messages into a list.

    Replaces the original hack of monkeypatching ``emit`` on a
    StreamHandler, which bypassed the handler's locking and formatting.
    """

    def __init__(self, sink):
        super().__init__()
        self._sink = sink

    def emit(self, record):
        self._sink.append(record.getMessage())


class TestNpuRMSNorm(unittest.TestCase):
    def setUp(self):
        # NOTE(review): this is a gated 7B checkpoint; swap for a small
        # public model if the test environment lacks access or memory.
        self.model_id = "meta-llama/Llama-2-7b-hf"
        self.model = AutoModel.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

    def test_npu_rmsnorm_integration(self):
        """Verify that NPU RMSNorm is wired into the model by apply_npu_rmsnorm."""
        if not _npu_available():
            self.skipTest("NPU is not available")

        # Capture log output from the kernels library.
        log_output = []
        log_handler = _ListHandler(log_output)
        logger = logging.getLogger("kernels")
        logger.addHandler(log_handler)
        original_level = logger.level
        logger.setLevel(logging.INFO)

        try:
            # Apply the NPU RMSNorm kernel.
            kernelized_model = apply_npu_rmsnorm(self.model, mode=Mode.TRAINING)

            # Check the logs for a replacement message.
            rmsnorm_replaced = False
            for message in log_output:
                if "RMSNorm" in message and ("replaced" in message or "using" in message):
                    rmsnorm_replaced = True
                    print(f"Found replacement message: {message}")
                    break

            self.assertTrue(rmsnorm_replaced, "RMSNorm layers were not replaced with NPU version")

            # Extra verification: the model should contain NpuRMSNorm instances.
            from rmsnorm_npu.layers import NpuRMSNorm

            npu_rmsnorm_count = 0
            for name, module in kernelized_model.named_modules():
                if isinstance(module, NpuRMSNorm):
                    npu_rmsnorm_count += 1
                    print(f"Found NpuRMSNorm at: {name}")

            self.assertGreater(npu_rmsnorm_count, 0, "No NpuRMSNorm instances found in the model")

            # Forward-pass sanity check.
            input_ids = torch.randint(0, 1000, (1, 100)).to("npu")
            with torch.no_grad():
                outputs = kernelized_model(input_ids)

            self.assertIsNotNone(outputs)
            # AutoModel loads the base model, whose output exposes
            # ``last_hidden_state`` — the original ``outputs.logits`` check
            # would raise AttributeError (logits require a *ForCausalLM head).
            self.assertTrue(torch.isfinite(outputs.last_hidden_state).all())

        finally:
            # Restore the logger's handlers and level.
            logger.removeHandler(log_handler)
            logger.setLevel(original_level)


if __name__ == "__main__":
    unittest.main()
|
Reference in New Issue
Block a user