mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 12:33:46 +08:00
Add support for NPU kernelize/layers (#155)
This change add support for Huawei Ascend NPUs. This is #146 with some formatting/typing fixes. Co-authored-by: zheliuyu <15750543867@163.com>
This commit is contained in:
@ -4,5 +4,6 @@ markers =
|
||||
rocm_only: marks tests that should only run on hosts with ROCm GPUs
|
||||
darwin_only: marks tests that should only run on macOS
|
||||
xpu_only: marks tests that should only run on hosts with Intel XPUs
|
||||
npu_only: marks tests that should only run on Ascend NPUs
|
||||
token: enable tests that require a write token
|
||||
is_staging_test: Marks tests that should only run on a staging environment
|
||||
|
@ -87,7 +87,7 @@ class Device:
|
||||
|
||||
Args:
|
||||
type (`str`):
|
||||
The device type (e.g., "cuda", "mps", "rocm", "xpu").
|
||||
The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu").
|
||||
properties ([`CUDAProperties`], *optional*):
|
||||
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
|
||||
|
||||
@ -109,6 +109,9 @@ class Device:
|
||||
|
||||
# XPU device (e.g., Intel(R) Data Center GPU Max 1550)
|
||||
xpu_device = Device(type="xpu")
|
||||
|
||||
# NPU device (Huawei Ascend)
|
||||
npu_device = Device(type="npu")
|
||||
```
|
||||
"""
|
||||
|
||||
@ -130,6 +133,8 @@ class Device:
|
||||
return _MPSRepos()
|
||||
elif self.type == "xpu":
|
||||
return _XPURepos()
|
||||
elif self.type == "npu":
|
||||
return _NPURepos()
|
||||
else:
|
||||
raise ValueError(f"Unknown device type: {self.type}")
|
||||
|
||||
@ -472,6 +477,26 @@ class _XPURepos(_DeviceRepos):
|
||||
self._repos = repos
|
||||
|
||||
|
||||
class _NPURepos(_DeviceRepos):
|
||||
_repos: Dict[Mode, LayerRepositoryProtocol]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._repos = {}
|
||||
|
||||
@property
|
||||
def repos(
|
||||
self,
|
||||
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
|
||||
return self._repos
|
||||
|
||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
||||
if device.type != "npu":
|
||||
raise ValueError(f"Device type must be 'npu', got {device.type}")
|
||||
|
||||
self._repos = repos
|
||||
|
||||
|
||||
class _MPSRepos(_DeviceRepos):
|
||||
_repos: Dict[Mode, LayerRepositoryProtocol]
|
||||
|
||||
@ -556,7 +581,7 @@ class _ROCMRepos(_DeviceRepos):
|
||||
|
||||
def _validate_device_type(device_type: str) -> None:
|
||||
"""Validate that the device type is supported."""
|
||||
supported_devices = {"cuda", "rocm", "mps", "xpu"}
|
||||
supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"}
|
||||
if device_type not in supported_devices:
|
||||
raise ValueError(
|
||||
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
|
||||
@ -814,7 +839,7 @@ def kernelize(
|
||||
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
|
||||
`torch.compile`.
|
||||
device (`Union[str, torch.device]`, *optional*):
|
||||
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
|
||||
The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu".
|
||||
The device type will be inferred from the model parameters when not provided.
|
||||
use_fallback (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the original forward method of modules when no compatible kernel could be found.
|
||||
@ -838,7 +863,7 @@ def kernelize(
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
mapping = {
|
||||
"LayerNorm": {
|
||||
"SiluAndMul": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
|
@ -35,6 +35,14 @@ def _get_cache_dir() -> Optional[str]:
|
||||
CACHE_DIR: Optional[str] = _get_cache_dir()
|
||||
|
||||
|
||||
def _get_privateuse_backend_name() -> Optional[str]:
|
||||
import torch
|
||||
|
||||
if hasattr(torch._C, "_get_privateuse1_backend_name"):
|
||||
return torch._C._get_privateuse1_backend_name()
|
||||
return None
|
||||
|
||||
|
||||
def build_variant() -> str:
|
||||
import torch
|
||||
|
||||
@ -49,9 +57,14 @@ def build_variant() -> str:
|
||||
elif torch.version.xpu is not None:
|
||||
version = torch.version.xpu
|
||||
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
|
||||
elif _get_privateuse_backend_name() == "npu":
|
||||
from torch_npu.utils.collect_env import get_cann_version # type: ignore[import-not-found]
|
||||
|
||||
cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
|
||||
compute_framework = f"cann{cann_major}{cann_minor}"
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
|
||||
"Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled."
|
||||
)
|
||||
|
||||
torch_version = parse(torch.__version__)
|
||||
|
@ -3,6 +3,8 @@ import sys
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels.utils import _get_privateuse_backend_name
|
||||
|
||||
has_cuda = (
|
||||
hasattr(torch.version, "cuda")
|
||||
and torch.version.cuda is not None
|
||||
@ -18,6 +20,7 @@ has_xpu = (
|
||||
and torch.version.xpu is not None
|
||||
and torch.xpu.device_count() > 0
|
||||
)
|
||||
has_npu = _get_privateuse_backend_name() == "npu"
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
@ -37,5 +40,7 @@ def pytest_runtest_setup(item):
|
||||
pytest.skip("skipping macOS-only test on non-macOS platform")
|
||||
if "xpu_only" in item.keywords and not has_xpu:
|
||||
pytest.skip("skipping XPU-only test on host without XPU")
|
||||
if "npu_only" in item.keywords and not has_npu:
|
||||
pytest.skip("skipping NPU-only test on host without NPU")
|
||||
if "token" in item.keywords and not item.config.getoption("--token"):
|
||||
pytest.skip("need --token option to run this test")
|
||||
|
@ -35,6 +35,7 @@ def test_load_locked():
|
||||
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
|
||||
|
||||
|
||||
@pytest.mark.cuda_only
|
||||
def test_layer_locked():
|
||||
project_dir = Path(__file__).parent / "layer_locking"
|
||||
|
||||
|
@ -21,14 +21,21 @@ from kernels.layer import (
|
||||
_KERNEL_MAPPING,
|
||||
_validate_layer,
|
||||
)
|
||||
from kernels.utils import install_kernel
|
||||
from kernels.utils import (
|
||||
_get_privateuse_backend_name,
|
||||
install_kernel,
|
||||
)
|
||||
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
)
|
||||
),
|
||||
"npu": LayerRepository(
|
||||
repo_id="kernels-ext-npu/SwiGlu",
|
||||
layer_name="SwiGlu",
|
||||
),
|
||||
},
|
||||
"SiluAndMulNoCompile": {
|
||||
"cuda": LayerRepository(
|
||||
@ -122,8 +129,10 @@ def device():
|
||||
return "cuda"
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
return "xpu"
|
||||
elif _get_privateuse_backend_name() == "npu":
|
||||
return "npu"
|
||||
|
||||
pytest.skip("No CUDA or XPU")
|
||||
pytest.skip("No CUDA, NPU or XPU")
|
||||
|
||||
|
||||
def test_arg_kinds():
|
||||
@ -204,10 +213,33 @@ def test_hub_forward_xpu():
|
||||
assert rms_norm_with_kernel.n_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.npu_only
|
||||
def test_hub_forward_npu():
|
||||
torch.manual_seed(0)
|
||||
|
||||
silu_and_mul = SiluAndMul()
|
||||
X = torch.randn((32, 64), device="npu")
|
||||
Y = silu_and_mul(X)
|
||||
|
||||
silu_and_mul_with_kernel = kernelize(
|
||||
SiluAndMulWithKernel(), device="npu", mode=Mode.INFERENCE
|
||||
)
|
||||
Y_kernel = silu_and_mul_with_kernel(X)
|
||||
|
||||
torch.testing.assert_close(Y_kernel, Y)
|
||||
|
||||
assert silu_and_mul.n_calls == 1
|
||||
assert silu_and_mul_with_kernel.n_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
|
||||
reason="Skip on xpu devices",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
_get_privateuse_backend_name() == "npu",
|
||||
reason="Skip on npu devices",
|
||||
)
|
||||
def test_rocm_kernel_mapping():
|
||||
"""Test that ROCm shorthand device mapping works correctly."""
|
||||
kernel_layer_mapping = {
|
||||
|
Reference in New Issue
Block a user