Add support for XPU layer repositories (#142)

This change adds support for XPU layer repositories, e.g.:

```
kernel_mapping = {
    "LigerRMSNorm": {
        "xpu": LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    },
}
```

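For illustration (not part of this commit), here is a minimal end-to-end sketch of how such a mapping is consumed on an XPU host. It assumes the usual `kernels` imports and a PyTorch build with XPU support; the module below is only a reference implementation.

```
import torch
from torch import nn

from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNorm(nn.Module):
    # Reference implementation; kernelize() swaps in the hub kernel's forward.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + self.variance_epsilon) * self.weight


# Same mapping as above: route LigerRMSNorm to the Liger Triton kernel on XPU.
register_kernel_mapping(
    {
        "LigerRMSNorm": {
            "xpu": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        },
    }
)

norm = kernelize(RMSNorm(1024).to("xpu"), mode=Mode.INFERENCE, device="xpu")
y = norm(torch.randn(4, 16, 1024, device="xpu"))
```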
Co-authored-by: YangKai0616 <kai.yang@intel.com>
Daniël de Kok authored 2025-09-11 15:51:02 +02:00, committed by GitHub
commit d383fdd4b4 (parent 07e5e8481a)
4 changed files with 119 additions and 19 deletions


@@ -3,3 +3,4 @@ markers =
cuda_only: marks tests that should only run on hosts with CUDA GPUs
rocm_only: marks tests that should only run on hosts with ROCm GPUs
darwin_only: marks tests that should only run on macOS
xpu_only: marks tests that should only run on hosts with Intel XPUs


@@ -87,7 +87,7 @@ class Device:
Args:
type (`str`):
The device type (e.g., "cuda", "mps", "rocm").
The device type (e.g., "cuda", "mps", "rocm", "xpu").
properties ([`CUDAProperties`], *optional*):
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
@@ -106,6 +106,9 @@ class Device:
# MPS device for Apple Silicon
mps_device = Device(type="mps")
# XPU device (e.g., Intel(R) Data Center GPU Max 1550)
xpu_device = Device(type="xpu")
```
"""
@@ -125,6 +128,8 @@ class Device:
return _ROCMRepos()
elif self.type == "mps":
return _MPSRepos()
elif self.type == "xpu":
return _XPURepos()
else:
raise ValueError(f"Unknown device type: {self.type}")
@@ -447,6 +452,26 @@ class _DeviceRepos(ABC):
...
class _XPURepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepositoryProtocol]
def __init__(self):
super().__init__()
self._repos = {}
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
return self._repos
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
if device.type != "xpu":
raise ValueError(f"Device type must be 'xpu', got {device.type}")
self._repos = repos
class _MPSRepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepositoryProtocol]
@@ -531,7 +556,7 @@ class _ROCMRepos(_DeviceRepos):
def _validate_device_type(device_type: str) -> None:
"""Validate that the device type is supported."""
supported_devices = {"cuda", "rocm", "mps"}
supported_devices = {"cuda", "rocm", "mps", "xpu"}
if device_type not in supported_devices:
raise ValueError(
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -789,7 +814,7 @@ def kernelize(
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
`torch.compile`.
device (`Union[str, torch.device]`, *optional*):
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
The device type will be inferred from the model parameters when not provided.
use_fallback (`bool`, *optional*, defaults to `True`):
Whether to use the original forward method of modules when no compatible kernel could be found.
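As a hedged illustration of the `device` argument specifically (the `Linear` layer name and subclass here are only stand-ins, loosely following the test file below):

```
import torch
from torch import nn

from kernels import Mode, kernelize, use_kernel_forward_from_hub


@use_kernel_forward_from_hub("Linear")
class Linear(nn.Linear):
    pass


# Pass the device type explicitly; "xpu" is now accepted alongside the others.
linear = kernelize(Linear(32, 32), mode=Mode.INFERENCE, device="xpu")

# Or move the module first and let kernelize infer the device type from its
# parameters; with the default use_fallback=True, the original forward is kept
# when no compatible kernel is registered for this device and mode.
linear = kernelize(Linear(32, 32).to("xpu"), mode=Mode.INFERENCE)
```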


@@ -13,6 +13,11 @@ has_rocm = (
and torch.version.hip is not None
and torch.cuda.device_count() > 0
)
has_xpu = (
hasattr(torch.version, "xpu")
and torch.version.xpu is not None
and torch.xpu.device_count() > 0
)
def pytest_runtest_setup(item):
@@ -22,3 +27,5 @@ def pytest_runtest_setup(item):
pytest.skip("skipping ROCm-only test on host without ROCm")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform")
if "xpu_only" in item.keywords and not has_xpu:
pytest.skip("skipping XPU-only test on host without XPU")


@@ -46,11 +46,37 @@ kernel_layer_mapping = {
layer_name="SiluAndMul",
)
},
"LigerRMSNorm": {
"xpu": LayerRepository(
repo_id="kernels-community/liger_kernels",
layer_name="LigerRMSNorm", # Triton
)
},
}
register_kernel_mapping(kernel_layer_mapping)
class RMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
super().__init__()
# Used to check that we called the hub kernel.
self.n_calls = 0
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, x: torch.Tensor):
self.n_calls += 1
var = x.pow(2).mean(-1, keepdim=True)
x_norm = x * torch.rsqrt(var + self.variance_epsilon)
return x_norm * self.weight
@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNormWithKernel(RMSNorm):
pass
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
@@ -90,6 +116,16 @@ class TorchLinearWithCounter(nn.Linear):
return super().forward(input)
@pytest.fixture
def device():
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
pytest.skip("No CUDA or XPU")
def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module):
@@ -147,6 +183,31 @@ def test_hub_forward_rocm():
assert silu_and_mul_with_kernel.n_calls in [0, 1]
@pytest.mark.xpu_only
def test_hub_forward_xpu():
torch.manual_seed(0)
hidden_size = 1024
weight = torch.ones(hidden_size, device="xpu")
rms_norm = RMSNorm(weight).to("xpu")
X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
Y = rms_norm(X)
rms_norm_with_kernel = kernelize(
RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
)
Y_kernel = rms_norm_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
assert rms_norm.n_calls == 1
assert rms_norm_with_kernel.n_calls == 0
@pytest.mark.skipif(
hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
reason="Skip on xpu devices",
)
def test_rocm_kernel_mapping():
"""Test that ROCm shorthand device mapping works correctly."""
kernel_layer_mapping = {
@@ -234,16 +295,16 @@ def test_layer_fallback_works():
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
def test_local_layer_repo():
def test_local_layer_repo(device):
# Fetch a kernel to the local cache.
package_name, path = install_kernel("kernels-test/backward-marker-test", "main")
linear = TorchLinearWithCounter(32, 32).to("cuda")
linear = TorchLinearWithCounter(32, 32).to(device)
with use_kernel_mapping(
{
"Linear": {
"cuda": LocalLayerRepository(
device: LocalLayerRepository(
# install_kernel will give the fully-resolved path.
repo_path=path.parent.parent,
package_name=package_name,
@@ -255,7 +316,7 @@ def test_local_layer_repo():
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
X = torch.randn(10, 32, device=device)
linear(X)
assert linear.n_calls == 0
@@ -323,6 +384,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
}
extra_mapping1 = {
@@ -340,6 +402,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}
@@ -358,6 +421,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}
assert (
@@ -371,6 +435,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}
assert (
@@ -393,6 +458,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}
assert (
@@ -404,6 +470,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
}
@@ -923,7 +990,7 @@ def test_kernel_modes_cross_fallback():
assert linear.n_calls == 2
def test_layer_versions():
def test_layer_versions(device):
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
@@ -934,20 +1001,20 @@ def test_layer_versions():
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<1.0.0",
@@ -955,13 +1022,13 @@
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<0.2.0",
@@ -969,13 +1036,13 @@
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.1.0,<0.2.0",
@@ -983,13 +1050,13 @@
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.2.0",
@@ -998,13 +1065,13 @@
}
):
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
kernelize(version, device="cuda", mode=Mode.INFERENCE)
kernelize(version, device=device, mode=Mode.INFERENCE)
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
revision="v0.1.0",