Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-20 12:33:46 +08:00)
Add support for XPU layer repositories (#142)
This change adds support for XPU layer repositories, e.g.:

```
kernel_mapping = {
    "LigerRMSNorm": {
        "xpu": LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    },
}
```

Co-authored-by: YangKai0616 <kai.yang@intel.com>
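For illustration (not part of the commit itself), the sketch below shows end-to-end usage of such a mapping, assuming a PyTorch build with XPU support and the `kernels` package; the `RMSNorm` module here is a stand-in for illustration, not the one from the test suite.

```
import torch
import torch.nn as nn

from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


# Stand-in module; the decorator marks its forward as replaceable by a hub kernel.
@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + self.variance_epsilon) * self.weight


register_kernel_mapping(
    {
        "LigerRMSNorm": {
            "xpu": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        },
    }
)

# On a host with an Intel XPU, kernelize swaps in the hub kernel's forward.
rms_norm = kernelize(RMSNorm(1024).to("xpu"), mode=Mode.INFERENCE, device="xpu")
y = rms_norm(torch.randn(4, 16, 1024, device="xpu"))
```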
@@ -3,3 +3,4 @@ markers =
     cuda_only: marks tests that should only run on hosts with CUDA GPUs
     rocm_only: marks tests that should only run on hosts with ROCm GPUs
     darwin_only: marks tests that should only run on macOS
+    xpu_only: marks tests that should only run on hosts with Intel XPUs
@@ -87,7 +87,7 @@ class Device:

     Args:
         type (`str`):
-            The device type (e.g., "cuda", "mps", "rocm").
+            The device type (e.g., "cuda", "mps", "rocm", "xpu").
         properties ([`CUDAProperties`], *optional*):
             Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
@@ -106,6 +106,9 @@ class Device:

     # MPS device for Apple Silicon
     mps_device = Device(type="mps")
+
+    # XPU device (e.g., Intel(R) Data Center GPU Max 1550)
+    xpu_device = Device(type="xpu")
     ```
     """
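As a side note (not part of the diff), the new device type can also be keyed explicitly with `Device(type="xpu")` in a mapping, mirroring how the tests further down parametrize `Device(type=device)`. A minimal sketch, assuming `Device`, `LayerRepository`, and `use_kernel_mapping` are importable from the `kernels` package as in the tests:

```
from kernels import Device, LayerRepository, use_kernel_mapping

# Sketch: Device(type="xpu") as an explicit mapping key, equivalent to the
# "xpu" string shorthand used in the commit message example.
xpu_mapping = {
    "LigerRMSNorm": {
        Device(type="xpu"): LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    },
}

with use_kernel_mapping(xpu_mapping):
    ...  # kernelize(...) calls inside this context see the XPU repository
```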
@@ -125,6 +128,8 @@ class Device:
             return _ROCMRepos()
         elif self.type == "mps":
             return _MPSRepos()
+        elif self.type == "xpu":
+            return _XPURepos()
         else:
             raise ValueError(f"Unknown device type: {self.type}")
@@ -447,6 +452,26 @@ class _DeviceRepos(ABC):
         ...


+class _XPURepos(_DeviceRepos):
+    _repos: Dict[Mode, LayerRepositoryProtocol]
+
+    def __init__(self):
+        super().__init__()
+        self._repos = {}
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
+        return self._repos
+
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
+        if device.type != "xpu":
+            raise ValueError(f"Device type must be 'xpu', got {device.type}")
+
+        self._repos = repos
+
+
 class _MPSRepos(_DeviceRepos):
     _repos: Dict[Mode, LayerRepositoryProtocol]
@@ -531,7 +556,7 @@ class _ROCMRepos(_DeviceRepos):

 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "rocm", "mps"}
+    supported_devices = {"cuda", "rocm", "mps", "xpu"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -789,7 +814,7 @@ def kernelize(
             `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
             `torch.compile`.
         device (`Union[str, torch.device]`, *optional*):
-            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
+            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
             The device type will be inferred from the model parameters when not provided.
         use_fallback (`bool`, *optional*, defaults to `True`):
             Whether to use the original forward method of modules when no compatible kernel could be found.
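To make the docstring change concrete, here is a minimal sketch (assuming the `kernels` package and a placeholder model already placed on an XPU): passing `device="xpu"` selects the XPU repositories explicitly, while omitting the argument lets `kernelize` infer the device type from the model's parameters.

```
import torch.nn as nn

from kernels import Mode, kernelize

model = nn.Sequential(nn.Linear(16, 16)).to("xpu")  # placeholder model

# Explicit device selection:
model = kernelize(model, mode=Mode.INFERENCE, device="xpu")

# Or let kernelize infer the device type from the parameters:
model = kernelize(model, mode=Mode.INFERENCE)
```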
@@ -13,6 +13,11 @@ has_rocm = (
     and torch.version.hip is not None
     and torch.cuda.device_count() > 0
 )
+has_xpu = (
+    hasattr(torch.version, "xpu")
+    and torch.version.xpu is not None
+    and torch.xpu.device_count() > 0
+)


 def pytest_runtest_setup(item):
@@ -22,3 +27,5 @@ def pytest_runtest_setup(item):
         pytest.skip("skipping ROCm-only test on host without ROCm")
     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
         pytest.skip("skipping macOS-only test on non-macOS platform")
+    if "xpu_only" in item.keywords and not has_xpu:
+        pytest.skip("skipping XPU-only test on host without XPU")
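For context, a test gated on the new marker would look like the following sketch (a hypothetical test, not part of the diff); it is skipped by `pytest_runtest_setup` above unless `torch.xpu` reports at least one device.

```
import pytest
import torch


@pytest.mark.xpu_only
def test_matmul_on_xpu():
    # Only runs on hosts where has_xpu is True.
    a = torch.randn(8, 8, device="xpu")
    b = torch.randn(8, 8, device="xpu")
    assert (a @ b).shape == (8, 8)
```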
@@ -46,11 +46,37 @@ kernel_layer_mapping = {
             layer_name="SiluAndMul",
         )
     },
+    "LigerRMSNorm": {
+        "xpu": LayerRepository(
+            repo_id="kernels-community/liger_kernels",
+            layer_name="LigerRMSNorm",  # Triton
+        )
+    },
 }

 register_kernel_mapping(kernel_layer_mapping)


+class RMSNorm(nn.Module):
+    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
+        super().__init__()
+        # Used to check that we called hub kernel.
+        self.n_calls = 0
+        self.weight = nn.Parameter(weight)
+        self.variance_epsilon = eps
+
+    def forward(self, x: torch.Tensor):
+        self.n_calls += 1
+        var = x.pow(2).mean(-1, keepdim=True)
+        x_norm = x * torch.rsqrt(var + self.variance_epsilon)
+        return x_norm * self.weight
+
+
+@use_kernel_forward_from_hub("LigerRMSNorm")
+class RMSNormWithKernel(RMSNorm):
+    pass
+
+
 class SiluAndMul(nn.Module):
     def __init__(self):
         super().__init__()
@@ -90,6 +116,16 @@ class TorchLinearWithCounter(nn.Linear):
         return super().forward(input)


+@pytest.fixture
+def device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+
+    pytest.skip("No CUDA or XPU")
+
+
 def test_arg_kinds():
     @use_kernel_forward_from_hub("ArgKind")
     class ArgKind(nn.Module):
@@ -147,6 +183,31 @@ def test_hub_forward_rocm():
     assert silu_and_mul_with_kernel.n_calls in [0, 1]


+@pytest.mark.xpu_only
+def test_hub_forward_xpu():
+    torch.manual_seed(0)
+
+    hidden_size = 1024
+    weight = torch.ones(hidden_size, device="xpu")
+    rms_norm = RMSNorm(weight).to("xpu")
+    X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
+    Y = rms_norm(X)
+
+    rms_norm_with_kernel = kernelize(
+        RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
+    )
+    Y_kernel = rms_norm_with_kernel(X)
+
+    torch.testing.assert_close(Y_kernel, Y)
+
+    assert rms_norm.n_calls == 1
+    assert rms_norm_with_kernel.n_calls == 0
+
+
+@pytest.mark.skipif(
+    hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
+    reason="Skip on xpu devices",
+)
 def test_rocm_kernel_mapping():
     """Test that ROCm shorthand device mapping works correctly."""
     kernel_layer_mapping = {
@@ -234,16 +295,16 @@ def test_layer_fallback_works():
         kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)


-def test_local_layer_repo():
+def test_local_layer_repo(device):
     # Fetch a kernel to the local cache.
     package_name, path = install_kernel("kernels-test/backward-marker-test", "main")

-    linear = TorchLinearWithCounter(32, 32).to("cuda")
+    linear = TorchLinearWithCounter(32, 32).to(device)

     with use_kernel_mapping(
         {
             "Linear": {
-                "cuda": LocalLayerRepository(
+                device: LocalLayerRepository(
                     # install_kernel will give the fully-resolved path.
                     repo_path=path.parent.parent,
                     package_name=package_name,
@@ -255,7 +316,7 @@ def test_local_layer_repo():
     ):
         kernelize(linear, mode=Mode.INFERENCE)

-    X = torch.randn(10, 32, device="cuda")
+    X = torch.randn(10, 32, device=device)
     linear(X)
     assert linear.n_calls == 0
@@ -323,6 +384,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
     }

     extra_mapping1 = {
@@ -340,6 +402,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }

@@ -358,6 +421,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (
@@ -371,6 +435,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (
@@ -393,6 +458,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (
@@ -404,6 +470,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
     }

@@ -923,7 +990,7 @@ def test_kernel_modes_cross_fallback():
     assert linear.n_calls == 2


-def test_layer_versions():
+def test_layer_versions(device):
     @use_kernel_forward_from_hub("Version")
     class Version(nn.Module):
         def forward(self) -> str:
@@ -934,20 +1001,20 @@ def test_layer_versions():
     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                 )
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.2.0"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version="<1.0.0",
@@ -955,13 +1022,13 @@ def test_layer_versions():
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.2.0"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version="<0.2.0",
@@ -969,13 +1036,13 @@ def test_layer_versions():
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.1.1"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version=">0.1.0,<0.2.0",
@@ -983,13 +1050,13 @@ def test_layer_versions():
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.1.1"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version=">0.2.0",
@@ -998,13 +1065,13 @@ def test_layer_versions():
             }
         }
     ):
         with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
-            kernelize(version, device="cuda", mode=Mode.INFERENCE)
+            kernelize(version, device=device, mode=Mode.INFERENCE)

     with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
         use_kernel_mapping(
             {
                 "Version": {
-                    Device(type="cuda"): LayerRepository(
+                    Device(type=device): LayerRepository(
                         repo_id="kernels-test/versions",
                         layer_name="Version",
                         revision="v0.1.0",