mirror of https://github.com/huggingface/kernels.git
synced 2025-11-06 07:04:32 +08:00

Compare commits: release-0.… → kernel-lay… (4 commits)

| Commit |
|---|
| 1789f8da4c |
| 6c00194680 |
| d6b51eefb7 |
| d383fdd4b4 |
.github/workflows/test.yml

@@ -51,7 +51,10 @@ jobs:
         run: uv run mypy src/kernels

       - name: Run tests
-        run: uv run pytest tests
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: |
+          uv run pytest tests

       - name: Check kernel conversion
         run: |
@@ -21,6 +21,8 @@
         title: Kernels
       - local: api/layers
         title: Layers
+      - local: cli
+        title: kernels CLI
       title: API Reference
   - sections:
       - local: kernel-requirements
@@ -21,6 +21,22 @@ activation.gelu_fast(y, x)
 print(y)
 ```

+### Using version bounds
+
+Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
+You can specify which version to download using Python version specifiers:
+
+```python
+import torch
+from kernels import get_kernel
+
+activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
+```
+
+This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
+is strongly recommended to specify a version bound, since a kernel author
+might push incompatible changes to the `main` branch.
+
 ## Checking Kernel Availability

 You can check if a specific kernel is available for your environment:
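For intuition, the version bound above is resolved by matching the repo's Git tags against a standard Python specifier. A standalone sketch of that matching using the `packaging` library (illustrative only, with a made-up tag list; not the library's actual resolver):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

tags = ["v0.0.3", "v0.0.4", "v0.0.5", "v0.1.0"]  # hypothetical repo tags
spec = SpecifierSet(">=0.0.4,<0.1.0")

# Strip the leading `v` and keep the highest tag satisfying the bound.
candidates = [Version(t[1:]) for t in tags if Version(t[1:]) in spec]
print(max(candidates))  # 0.0.5 -> tag v0.0.5 is downloaded
```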
docs/source/cli.md (new file)

@@ -0,0 +1,15 @@
+# Kernels CLI Reference
+
+## Main Functions
+
+### kernels upload
+
+Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
+your kernel builds to the Hub.
+
+**Notes**:
+
+* This takes care of creating a repository on the Hub with the `repo_id` provided.
+* If a repo with the `repo_id` already exists and contains a `build` with the build variant
+  being uploaded, the existing files under that variant will be deleted.
+* Make sure you are authenticated (run `hf auth login` if needed) to be able to upload to the Hub.
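The same upload path can be driven programmatically, which is how `tests/test_kernel_upload.py` below exercises it. A minimal sketch, assuming a build tree at `my-kernel/build/<variant>/...` and a repo you can write to (the paths and repo name are placeholders):

```python
from dataclasses import dataclass
from pathlib import Path

from kernels.cli import upload_kernels


@dataclass
class UploadArgs:
    # Mirrors the `kernels upload` CLI arguments.
    kernel_dir: Path
    repo_id: str
    private: bool


# Uploads my-kernel/build/* into the `build/` folder of the target repo.
upload_kernels(UploadArgs(Path("my-kernel"), "hub-username/kernel", False))
```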
@@ -34,6 +34,8 @@ Kernels are versioned on the Hub using Git tags. Version tags must be of
 the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
 to resolve the version constraints.

+We recommend using [semver](https://semver.org/) to version kernels.
+
 ## Native Python module

 Kernels will typically contain a native Python module with precompiled

@@ -50,7 +52,6 @@ have dynamic library dependencies outside:
   for compatibility with Python 3.9 and later.
 - Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
   This means that the extension **must not** use symbol versions higher than:

   - GLIBC 2.28
   - GLIBCXX 3.4.24
   - CXXABI 1.3.11
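To audit the symbol-version ceilings above, one option (a sketch, not part of the repository's tooling; the `.so` filename is made up) is to scan the built extension's dynamic symbols with `readelf`:

```python
import re
import subprocess

# Dump the dynamic symbol table of the built extension.
out = subprocess.run(
    ["readelf", "--dyn-syms", "--wide", "_activation_abi3.so"],
    capture_output=True, text=True, check=True,
).stdout

# Versioned references appear as `symbol@GLIBC_2.28`, `symbol@GLIBCXX_3.4.24`, ...
for prefix in ("GLIBC_", "GLIBCXX_", "CXXABI_"):
    versions = sorted(set(re.findall(rf"@{prefix}([\d.]+)", out)))
    print(prefix, versions)  # nothing should exceed the ceilings listed above
```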
@@ -157,6 +157,33 @@ with use_kernel_mapping(kernel_layer_mapping):
 This ensures that the mapping is not active anymore outside the
 `with`-scope.

+### Using version bounds
+
+Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
+You can specify which version of the kernel to download using Python version
+specifiers:
+
+```python
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        "cuda": LayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
+            version=">=0.0.4,<0.1.0",
+        ),
+        "rocm": LayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
+            version=">=0.0.4,<0.1.0",
+        )
+    }
+}
+```
+
+This will get the layer from the latest kernel tagged `v0.0.z` where `z` is at
+least 4. It is strongly recommended to specify a version bound, since a
+kernel author might push incompatible changes to the `main` branch.
+
 ### Registering kernels for specific modes

 You might want to register two different kernels for a particular layer,
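A related constraint that the tests below rely on: a `LayerRepository` takes either an explicit `revision` or a `version` bound, never both. A minimal sketch of the failure mode:

```python
from kernels import LayerRepository

# Raises ValueError ("Either a revision or a version ... not both").
LayerRepository(
    repo_id="kernels-community/activation",
    layer_name="SiluAndMul",
    revision="v0.0.4",
    version=">=0.0.4,<0.1.0",
)
```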
@@ -1,6 +1,6 @@
 [project]
 name = "kernels"
-version = "0.10.1"
+version = "0.10.1.dev0"
 description = "Download compute kernels"
 authors = [
     { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -3,3 +3,5 @@ markers =
     cuda_only: marks tests that should only run on hosts with CUDA GPUs
     rocm_only: marks tests that should only run on hosts with ROCm GPUs
     darwin_only: marks tests that should only run on macOS
+    xpu_only: marks tests that should only run on hosts with Intel XPUs
+    token: enable tests that require a write token
@@ -4,6 +4,8 @@ import json
 import sys
 from pathlib import Path

+from huggingface_hub import create_repo, upload_folder
+
 from kernels.compat import tomllib
 from kernels.lockfile import KernelLock, get_kernel_locks
 from kernels.utils import install_kernel, install_kernel_all_variants

@@ -31,6 +33,24 @@ def main():
     )
     download_parser.set_defaults(func=download_kernels)

+    upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub")
+    upload_parser.add_argument(
+        "kernel_dir",
+        type=Path,
+        help="Directory of the kernel build",
+    )
+    upload_parser.add_argument(
+        "--repo_id",
+        type=str,
+        help="Repository ID to use to upload to the Hugging Face Hub",
+    )
+    upload_parser.add_argument(
+        "--private",
+        action="store_true",
+        help="If the repository should be private.",
+    )
+    upload_parser.set_defaults(func=upload_kernels)
+
     lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
     lock_parser.add_argument(
         "project_dir",

@@ -153,6 +173,33 @@ def lock_kernels(args):
         json.dump(all_locks, f, cls=_JSONEncoder, indent=2)


+def upload_kernels(args):
+    kernel_dir = Path(args.kernel_dir).resolve()
+    build_dir = kernel_dir / "build"
+    if not kernel_dir.is_dir():
+        raise ValueError(f"{kernel_dir} is not a directory")
+    if not build_dir.is_dir():
+        raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
+
+    repo_id = create_repo(
+        repo_id=args.repo_id, private=args.private, exist_ok=True
+    ).repo_id
+
+    delete_patterns: set[str] = set()
+    for build_variant in build_dir.iterdir():
+        if build_variant.is_dir():
+            delete_patterns.add(f"{build_variant.name}/**")
+
+    upload_folder(
+        repo_id=repo_id,
+        folder_path=build_dir,
+        path_in_repo="build",
+        delete_patterns=list(delete_patterns),
+        commit_message="Build uploaded using `kernels`.",
+    )
+    print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.")
+
+
 class _JSONEncoder(json.JSONEncoder):
     def default(self, o):
         if dataclasses.is_dataclass(o):
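For a concrete feel of the deletion logic in `upload_kernels`: every first-level directory under `build/` becomes a glob passed to `upload_folder`, so stale files of re-uploaded variants are pruned while other repo content is untouched. A small sketch of the same computation (the local path is illustrative):

```python
from pathlib import Path

build_dir = Path("my-kernel/build")  # illustrative local build tree

# One `<variant>/**` glob per build variant directory, as in upload_kernels.
delete_patterns = {
    f"{variant.name}/**" for variant in build_dir.iterdir() if variant.is_dir()
}
print(delete_patterns)  # e.g. {"torch-universal/**"}
```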
@@ -87,7 +87,7 @@ class Device:

     Args:
         type (`str`):
-            The device type (e.g., "cuda", "mps", "rocm").
+            The device type (e.g., "cuda", "mps", "rocm", "xpu").
         properties ([`CUDAProperties`], *optional*):
             Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.

@@ -106,6 +106,9 @@ class Device:

         # MPS device for Apple Silicon
         mps_device = Device(type="mps")
+
+        # XPU device (e.g., Intel(R) Data Center GPU Max 1550)
+        xpu_device = Device(type="xpu")
         ```
     """

@@ -125,6 +128,8 @@ class Device:
             return _ROCMRepos()
         elif self.type == "mps":
             return _MPSRepos()
+        elif self.type == "xpu":
+            return _XPURepos()
         else:
             raise ValueError(f"Unknown device type: {self.type}")

@@ -311,7 +316,7 @@ class LayerRepository:
         return hash((self.layer_name, self._repo_id, self._revision, self._version))

     def __str__(self) -> str:
-        return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
+        return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"


 class LocalLayerRepository:

@@ -367,7 +372,7 @@ class LocalLayerRepository:
         return hash((self.layer_name, self._repo_path, self._package_name))

     def __str__(self) -> str:
-        return f"`{self._repo_path}` (package: {self._package_name}) for layer `{self.layer_name}`"
+        return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.layer_name}`"


 class LockedLayerRepository:

@@ -422,7 +427,7 @@ class LockedLayerRepository:
         return hash((self.layer_name, self._repo_id))

     def __str__(self) -> str:
-        return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
+        return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"


 _CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}

@@ -447,6 +452,26 @@ class _DeviceRepos(ABC):
         ...


+class _XPURepos(_DeviceRepos):
+    _repos: Dict[Mode, LayerRepositoryProtocol]
+
+    def __init__(self):
+        super().__init__()
+        self._repos = {}
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
+        return self._repos
+
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
+        if device.type != "xpu":
+            raise ValueError(f"Device type must be 'xpu', got {device.type}")
+
+        self._repos = repos
+
+
 class _MPSRepos(_DeviceRepos):
     _repos: Dict[Mode, LayerRepositoryProtocol]

@@ -531,7 +556,7 @@ class _ROCMRepos(_DeviceRepos):

 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "rocm", "mps"}
+    supported_devices = {"cuda", "rocm", "mps", "xpu"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"

@@ -789,7 +814,7 @@ def kernelize(
             `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
             `torch.compile`.
         device (`Union[str, torch.device]`, *optional*):
-            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
+            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
             The device type will be inferred from the model parameters when not provided.
         use_fallback (`bool`, *optional*, defaults to `True`):
             Whether to use the original forward method of modules when no compatible kernel could be found.

@@ -995,7 +1020,7 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
     return layer


-def _validate_layer(*, check_cls, cls):
+def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
     import torch.nn as nn

     # The layer must at least have the following properties: (1) it

@@ -1004,12 +1029,12 @@ def _validate_layer(*, check_cls, cls):
     # methods.

     if not issubclass(cls, nn.Module):
-        raise TypeError(f"Layer `{cls}` is not a Torch layer.")
+        raise TypeError(f"Layer `{cls.__name__}` is not a Torch layer.")

     # We verify statelessness by checking that the layer does not have its own
     # constructor (since the constructor could add member variables)...
     if cls.__init__ is not nn.Module.__init__:
-        raise TypeError("Layer must not override nn.Module constructor.")
+        raise TypeError(f"{repo} must not override nn.Module constructor.")

     # ... or predefined member variables.
     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}

@@ -1017,7 +1042,9 @@ def _validate_layer(*, check_cls, cls):
     difference = cls_members - torch_module_members
     # verify if: difference ⊄ {"can_torch_compile", "has_backward"}
     if not difference <= {"can_torch_compile", "has_backward"}:
-        raise TypeError("Layer must not contain additional members.")
+        raise TypeError(
+            f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
+        )

     # Check whether the forward signatures are similar.
     params = inspect.signature(cls.forward).parameters

@@ -1025,13 +1052,13 @@ def _validate_layer(*, check_cls, cls):

     if len(params) != len(ref_params):
         raise TypeError(
-            "Forward signature does not match: different number of arguments."
+            f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
         )

     for param, ref_param in zip(params.values(), ref_params.values()):
         if param.kind != ref_param.kind:
             raise TypeError(
-                f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
+                f"Forward signature of {repo} does not match `{check_cls.__name__}`: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
             )

@@ -1148,7 +1175,7 @@ def _get_layer_memoize(
         return layer

     layer = _get_kernel_layer(repo)
-    _validate_layer(check_cls=module_class, cls=layer)
+    _validate_layer(check_cls=module_class, cls=layer, repo=repo)
     _CACHED_LAYER[repo] = layer

     return layer
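The net effect of the layer-registry changes is that `"xpu"` is now a first-class device string: it resolves to its own `_XPURepos` bucket and passes `_validate_device_type`. A short sketch of a mapping that becomes legal (same repo and layer as in the tests below):

```python
from kernels import LayerRepository, register_kernel_mapping

register_kernel_mapping(
    {
        "LigerRMSNorm": {
            # Plain string shorthand for Device(type="xpu").
            "xpu": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",  # Triton kernel, runs on XPU
            )
        }
    }
)
```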
@@ -13,6 +13,19 @@ has_rocm = (
     and torch.version.hip is not None
     and torch.cuda.device_count() > 0
 )
+has_xpu = (
+    hasattr(torch.version, "xpu")
+    and torch.version.xpu is not None
+    and torch.xpu.device_count() > 0
+)
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--token",
+        action="store_true",
+        help="run tests that require a token with write permissions",
+    )


 def pytest_runtest_setup(item):

@@ -22,3 +35,7 @@ def pytest_runtest_setup(item):
         pytest.skip("skipping ROCm-only test on host without ROCm")
     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
         pytest.skip("skipping macOS-only test on non-macOS platform")
+    if "xpu_only" in item.keywords and not has_xpu:
+        pytest.skip("skipping XPU-only test on host without XPU")
+    if "token" in item.keywords and not item.config.getoption("--token"):
+        pytest.skip("need --token option to run this test")
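With these hooks in place, gating a test is just a marker. A minimal sketch (test bodies are placeholders):

```python
import pytest


@pytest.mark.xpu_only
def test_runs_on_xpu():
    ...  # skipped automatically when torch.xpu reports no device


@pytest.mark.token
def test_writes_to_hub():
    ...  # skipped unless pytest is invoked with --token
```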
tests/test_kernel_upload.py (new file)

@@ -0,0 +1,88 @@
+import logging
+import os
+import re
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List
+
+import pytest
+from huggingface_hub import model_info
+
+from kernels.cli import upload_kernels
+
+REPO_ID = "kernels-test/kernels-upload-test"
+
+PY_CONTENT = """\
+#!/usr/bin/env python3
+
+def main():
+    print("Hello from torch-universal!")
+
+if __name__ == "__main__":
+    main()
+"""
+
+
+@dataclass
+class UploadArgs:
+    kernel_dir: None
+    repo_id: None
+    private: False
+
+
+def next_filename(path: Path) -> Path:
+    """
+    Given a path like foo_2050.py, return foo_2051.py.
+    """
+    m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
+    if not m:
+        raise ValueError(
+            f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
+        )
+
+    prefix, number, suffix = m.groups()
+    new_number = str(int(number) + 1).zfill(len(number))
+    return path.with_name(f"{prefix}{new_number}{suffix}")
+
+
+def get_filename_to_change(repo_filenames):
+    for f in repo_filenames:
+        if "foo" in f and f.endswith(".py"):
+            filename_to_change = os.path.basename(f)
+            break
+    assert filename_to_change
+    return filename_to_change
+
+
+def get_filenames_from_a_repo(repo_id: str) -> List[str]:
+    try:
+        repo_info = model_info(repo_id=repo_id, files_metadata=True)
+        repo_siblings = repo_info.siblings
+        if repo_siblings is not None:
+            return [f.rfilename for f in repo_siblings]
+        else:
+            raise ValueError("No repo siblings found.")
+    except Exception as e:
+        logging.error(f"Error connecting to the Hub: {e}.")
+
+
+@pytest.mark.token
+def test_kernel_upload_deletes_as_expected():
+    repo_filenames = get_filenames_from_a_repo(REPO_ID)
+    filename_to_change = get_filename_to_change(repo_filenames)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = f"{tmpdir}/build/torch-universal/upload_test"
+        build_dir = Path(path)
+        build_dir.mkdir(parents=True, exist_ok=True)
+        changed_filename = next_filename(Path(filename_to_change))
+        script_path = build_dir / changed_filename
+        script_path.write_text(PY_CONTENT)
+        upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
+
+    repo_filenames = get_filenames_from_a_repo(REPO_ID)
+    assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
+    assert not any(
+        str(filename_to_change) in k for k in repo_filenames
+    ), f"{repo_filenames=}"
@@ -46,11 +46,37 @@ kernel_layer_mapping = {
             layer_name="SiluAndMul",
         )
     },
+    "LigerRMSNorm": {
+        "xpu": LayerRepository(
+            repo_id="kernels-community/liger_kernels",
+            layer_name="LigerRMSNorm",  # Triton
+        )
+    },
 }

 register_kernel_mapping(kernel_layer_mapping)


+class RMSNorm(nn.Module):
+    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
+        super().__init__()
+        # Used to check that we called hub kernel.
+        self.n_calls = 0
+        self.weight = nn.Parameter(weight)
+        self.variance_epsilon = eps
+
+    def forward(self, x: torch.Tensor):
+        self.n_calls += 1
+        var = x.pow(2).mean(-1, keepdim=True)
+        x_norm = x * torch.rsqrt(var + self.variance_epsilon)
+        return x_norm * self.weight
+
+
+@use_kernel_forward_from_hub("LigerRMSNorm")
+class RMSNormWithKernel(RMSNorm):
+    pass
+
+
 class SiluAndMul(nn.Module):
     def __init__(self):
         super().__init__()

@@ -90,6 +116,16 @@ class TorchLinearWithCounter(nn.Linear):
         return super().forward(input)


+@pytest.fixture
+def device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+
+    pytest.skip("No CUDA or XPU")
+
+
 def test_arg_kinds():
     @use_kernel_forward_from_hub("ArgKind")
     class ArgKind(nn.Module):

@@ -147,6 +183,31 @@ def test_hub_forward_rocm():
     assert silu_and_mul_with_kernel.n_calls in [0, 1]


+@pytest.mark.xpu_only
+def test_hub_forward_xpu():
+    torch.manual_seed(0)
+
+    hidden_size = 1024
+    weight = torch.ones(hidden_size, device="xpu")
+    rms_norm = RMSNorm(weight).to("xpu")
+    X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
+    Y = rms_norm(X)
+
+    rms_norm_with_kernel = kernelize(
+        RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
+    )
+    Y_kernel = rms_norm_with_kernel(X)
+
+    torch.testing.assert_close(Y_kernel, Y)
+
+    assert rms_norm.n_calls == 1
+    assert rms_norm_with_kernel.n_calls == 0
+
+
+@pytest.mark.skipif(
+    hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
+    reason="Skip on xpu devices",
+)
 def test_rocm_kernel_mapping():
     """Test that ROCm shorthand device mapping works correctly."""
     kernel_layer_mapping = {

@@ -234,16 +295,16 @@ def test_layer_fallback_works():
     kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)


-def test_local_layer_repo():
+def test_local_layer_repo(device):
     # Fetch a kernel to the local cache.
     package_name, path = install_kernel("kernels-test/backward-marker-test", "main")

-    linear = TorchLinearWithCounter(32, 32).to("cuda")
+    linear = TorchLinearWithCounter(32, 32).to(device)

     with use_kernel_mapping(
         {
             "Linear": {
-                "cuda": LocalLayerRepository(
+                device: LocalLayerRepository(
                     # install_kernel will give the fully-resolved path.
                     repo_path=path.parent.parent,
                     package_name=package_name,

@@ -255,7 +316,7 @@ def test_local_layer_repo():
     ):
         kernelize(linear, mode=Mode.INFERENCE)

-    X = torch.randn(10, 32, device="cuda")
+    X = torch.randn(10, 32, device=device)
     linear(X)
     assert linear.n_calls == 0

@@ -323,6 +384,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
     }

     extra_mapping1 = {

@@ -340,6 +402,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }

@@ -358,6 +421,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (

@@ -371,6 +435,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (

@@ -393,6 +458,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (

@@ -404,6 +470,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
     }

@@ -413,26 +480,43 @@ def test_validate_kernel_layer():
             super().__init__(*args, **kwargs)
             self.foo = 42

-    with pytest.raises(TypeError, match="not override"):
-        _validate_layer(cls=BadLayer, check_cls=SiluAndMul)
+    def stub_repo(layer):
+        return LayerRepository(
+            repo_id="kernels-test/nonexisting", layer_name=layer.__name__
+        )
+
+    with pytest.raises(
+        TypeError,
+        match="`kernels-test/nonexisting`.*layer `BadLayer` must not override",
+    ):
+        _validate_layer(cls=BadLayer, check_cls=SiluAndMul, repo=stub_repo(BadLayer))

     class BadLayer2(nn.Module):
         foo: int = 42

-    with pytest.raises(TypeError, match="not contain additional members"):
-        _validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
+    with pytest.raises(
+        TypeError,
+        match="`kernels-test/nonexisting`.*layer `BadLayer2` must not contain.*SiluAndMul",
+    ):
+        _validate_layer(cls=BadLayer2, check_cls=SiluAndMul, repo=stub_repo(BadLayer2))

     class BadLayer3(nn.Module):
         def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...

-    with pytest.raises(TypeError, match="different number of arguments"):
-        _validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
+    with pytest.raises(
+        TypeError,
+        match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer3` does not match `SiluAndMul`: different number of arguments",
+    ):
+        _validate_layer(cls=BadLayer3, check_cls=SiluAndMul, repo=stub_repo(BadLayer3))

     class BadLayer4(nn.Module):
         def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...

-    with pytest.raises(TypeError, match="different kind of arguments"):
-        _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
+    with pytest.raises(
+        TypeError,
+        match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer4` does not match `SiluAndMul`: different kind of arguments",
+    ):
+        _validate_layer(cls=BadLayer4, check_cls=SiluAndMul, repo=stub_repo(BadLayer4))


 @pytest.mark.cuda_only

@@ -923,7 +1007,7 @@ def test_kernel_modes_cross_fallback():
     assert linear.n_calls == 2


-def test_layer_versions():
+def test_layer_versions(device):
     @use_kernel_forward_from_hub("Version")
     class Version(nn.Module):
         def forward(self) -> str:

@@ -934,20 +1018,20 @@ def test_layer_versions():
     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                 )
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
     assert version() == "0.2.0"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version="<1.0.0",

@@ -955,13 +1039,13 @@ def test_layer_versions():
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
     assert version() == "0.2.0"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version="<0.2.0",

@@ -969,13 +1053,13 @@ def test_layer_versions():
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
     assert version() == "0.1.1"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version=">0.1.0,<0.2.0",

@@ -983,13 +1067,13 @@ def test_layer_versions():
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
     assert version() == "0.1.1"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version=">0.2.0",

@@ -998,13 +1082,13 @@ def test_layer_versions():
         }
     ):
         with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
-            kernelize(version, device="cuda", mode=Mode.INFERENCE)
+            kernelize(version, device=device, mode=Mode.INFERENCE)

     with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
         use_kernel_mapping(
             {
                 "Version": {
-                    Device(type="cuda"): LayerRepository(
+                    Device(type=device): LayerRepository(
                         repo_id="kernels-test/versions",
                         layer_name="Version",
                         revision="v0.1.0",