Add support for NPU kernelize/layers (#155)

This change adds support for Huawei Ascend NPUs. It is #146 with some formatting/typing fixes.

Co-authored-by: zheliuyu <15750543867@163.com>
Author: Daniël de Kok
Date: 2025-09-23 10:46:41 +02:00
Committed by: GitHub
Parent: dfee307d54
Commit: fb8cd99a2c
6 changed files with 85 additions and 8 deletions
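For orientation, a minimal usage sketch assembled from the tests in this diff (the kernels-ext-npu/SwiGlu repo ID and the device="npu" argument appear below; the top-level kernels imports are assumed to be the library's public API):

import torch
from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)

# Reference implementation; kernelize() swaps in the hub kernel when one
# is mapped for the current device.
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

mapping = {
    "SiluAndMul": {
        "npu": LayerRepository(
            repo_id="kernels-ext-npu/SwiGlu",
            layer_name="SwiGlu",
        ),
    }
}

with use_kernel_mapping(mapping):
    model = kernelize(SiluAndMul(), device="npu", mode=Mode.INFERENCE)
    y = model(torch.randn((32, 64), device="npu"))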

View File

@@ -4,5 +4,6 @@ markers =
rocm_only: marks tests that should only run on hosts with ROCm GPUs
darwin_only: marks tests that should only run on macOS
xpu_only: marks tests that should only run on hosts with Intel XPUs
npu_only: marks tests that should only run on Ascend NPUs
token: enable tests that require a write token
is_staging_test: Marks tests that should only run on a staging environment

View File

@@ -87,7 +87,7 @@ class Device:
Args:
type (`str`):
The device type (e.g., "cuda", "mps", "rocm", "xpu").
The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu").
properties ([`CUDAProperties`], *optional*):
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
@@ -109,6 +109,9 @@ class Device:
# XPU device (e.g., Intel(R) Data Center GPU Max 1550)
xpu_device = Device(type="xpu")
# NPU device (Huawei Ascend)
npu_device = Device(type="npu")
```
"""
@@ -130,6 +133,8 @@ class Device:
return _MPSRepos()
elif self.type == "xpu":
return _XPURepos()
elif self.type == "npu":
return _NPURepos()
else:
raise ValueError(f"Unknown device type: {self.type}")
@@ -472,6 +477,26 @@ class _XPURepos(_DeviceRepos):
self._repos = repos
class _NPURepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepositoryProtocol]
def __init__(self):
super().__init__()
self._repos = {}
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
return self._repos
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
if device.type != "npu":
raise ValueError(f"Device type must be 'npu', got {device.type}")
self._repos = repos
class _MPSRepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepositoryProtocol]
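_NPURepos follows the same shape as the other per-device containers: it stores a Mode-to-repository mapping and rejects inserts for mismatched device types. A hedged sketch of a registration that would land in this container (register_kernel_mapping is the library's existing registration entry point; the repo ID is taken from the tests in this commit):

from kernels import Device, LayerRepository, register_kernel_mapping

register_kernel_mapping(
    {
        "SiluAndMul": {
            # Both Device(type="npu") and the shorthand string "npu" work
            # as keys, mirroring the other device types.
            Device(type="npu"): LayerRepository(
                repo_id="kernels-ext-npu/SwiGlu",
                layer_name="SwiGlu",
            ),
        }
    }
)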
@@ -556,7 +581,7 @@ class _ROCMRepos(_DeviceRepos):
def _validate_device_type(device_type: str) -> None:
"""Validate that the device type is supported."""
supported_devices = {"cuda", "rocm", "mps", "xpu"}
supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"}
if device_type not in supported_devices:
raise ValueError(
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -814,7 +839,7 @@ def kernelize(
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
`torch.compile`.
device (`Union[str, torch.device]`, *optional*):
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu".
The device type will be inferred from the model parameters when not provided.
use_fallback (`bool`, *optional*, defaults to `True`):
Whether to use the original forward method of modules when no compatible kernel could be found.
@@ -838,7 +863,7 @@ def kernelize(
return F.silu(x[..., :d]) * x[..., d:]
mapping = {
"LayerNorm": {
"SiluAndMul": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",

View File

@@ -35,6 +35,14 @@ def _get_cache_dir() -> Optional[str]:
CACHE_DIR: Optional[str] = _get_cache_dir()
def _get_privateuse_backend_name() -> Optional[str]:
import torch
if hasattr(torch._C, "_get_privateuse1_backend_name"):
return torch._C._get_privateuse1_backend_name()
return None
def build_variant() -> str:
import torch
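Background on the helper above: torch_npu registers itself as PyTorch's PrivateUse1 backend under the name "npu", which is what this detection relies on. A hedged illustration (the rename hook is standard PyTorch; that torch_npu calls it is my reading of the setup):

import torch

# On a stock build the PrivateUse1 backend is named "privateuseone".
print(torch._C._get_privateuse1_backend_name())

# Importing torch_npu on an Ascend host renames the backend, conceptually via
# torch.utils.rename_privateuse1_backend("npu"), after which the helper
# above returns "npu".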
@@ -49,9 +57,14 @@ def build_variant() -> str:
elif torch.version.xpu is not None:
version = torch.version.xpu
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
elif _get_privateuse_backend_name() == "npu":
from torch_npu.utils.collect_env import get_cann_version # type: ignore[import-not-found]
cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
compute_framework = f"cann{cann_major}{cann_minor}"
else:
raise AssertionError(
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
"Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled."
)
torch_version = parse(torch.__version__)
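A hedged illustration of the variant component the NPU branch computes. The indexing assumes a CANN version string shaped like "8.1.RC1", so characters 0 and 2 are the major and minor digits (the version string here is hypothetical):

cann_version = "8.1.RC1"  # hypothetical; returned by torch_npu's get_cann_version()
cann_major, cann_minor = cann_version[0], cann_version[2]
compute_framework = f"cann{cann_major}{cann_minor}"
print(compute_framework)  # -> "cann81"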

View File

@@ -3,6 +3,8 @@ import sys
import pytest
import torch
from kernels.utils import _get_privateuse_backend_name
has_cuda = (
hasattr(torch.version, "cuda")
and torch.version.cuda is not None
@@ -18,6 +20,7 @@ has_xpu = (
and torch.version.xpu is not None
and torch.xpu.device_count() > 0
)
has_npu = _get_privateuse_backend_name() == "npu"
def pytest_addoption(parser):
@@ -37,5 +40,7 @@ def pytest_runtest_setup(item):
pytest.skip("skipping macOS-only test on non-macOS platform")
if "xpu_only" in item.keywords and not has_xpu:
pytest.skip("skipping XPU-only test on host without XPU")
if "npu_only" in item.keywords and not has_npu:
pytest.skip("skipping NPU-only test on host without NPU")
if "token" in item.keywords and not item.config.getoption("--token"):
pytest.skip("need --token option to run this test")

View File

@@ -35,6 +35,7 @@ def test_load_locked():
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
@pytest.mark.cuda_only
def test_layer_locked():
project_dir = Path(__file__).parent / "layer_locking"

View File

@@ -21,14 +21,21 @@ from kernels.layer import (
_KERNEL_MAPPING,
_validate_layer,
)
from kernels.utils import install_kernel
from kernels.utils import (
_get_privateuse_backend_name,
install_kernel,
)
kernel_layer_mapping = {
"SiluAndMul": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
)
),
"npu": LayerRepository(
repo_id="kernels-ext-npu/SwiGlu",
layer_name="SwiGlu",
),
},
"SiluAndMulNoCompile": {
"cuda": LayerRepository(
@@ -122,8 +129,10 @@ def device():
return "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
elif _get_privateuse_backend_name() == "npu":
return "npu"
pytest.skip("No CUDA or XPU")
pytest.skip("No CUDA, NPU or XPU")
def test_arg_kinds():
@@ -204,10 +213,33 @@ def test_hub_forward_xpu():
assert rms_norm_with_kernel.n_calls == 0
@pytest.mark.npu_only
def test_hub_forward_npu():
torch.manual_seed(0)
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), device="npu")
Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(
SiluAndMulWithKernel(), device="npu", mode=Mode.INFERENCE
)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert silu_and_mul.n_calls == 1
assert silu_and_mul_with_kernel.n_calls == 0
@pytest.mark.skipif(
hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
reason="Skip on xpu devices",
)
@pytest.mark.skipif(
_get_privateuse_backend_name() == "npu",
reason="Skip on npu devices",
)
def test_rocm_kernel_mapping():
"""Test that ROCm shorthand device mapping works correctly."""
kernel_layer_mapping = {