mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-28 09:52:19 +08:00
Compare commits
19 Commits
stateful-l
...
upload-hub
| Author | SHA1 | Date | |
|---|---|---|---|
| 81fb5d34bb | |||
| fd237b04bd | |||
| 0eb07f198c | |||
| 620bf75864 | |||
| 8f78116b87 | |||
| f1782d1914 | |||
| f6c901205c | |||
| 6899e4bfe1 | |||
| ad9cba28f7 | |||
| 2f1986e01a | |||
| ab607022c0 | |||
| 02cbff1d0f | |||
| d2d8f77d97 | |||
| 421f09e08a | |||
| e2d43815c1 | |||
| 7ee9660d2c | |||
| b56106966e | |||
| 1720baac7d | |||
| a6dc55ddb1 |
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@ -51,9 +51,8 @@ jobs:
|
|||||||
run: uv run mypy src/kernels
|
run: uv run mypy src/kernels
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
env:
|
|
||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
|
||||||
run: |
|
run: |
|
||||||
|
export HF_TOKEN=${{ secrets.HF_TOKEN }}
|
||||||
uv run pytest tests
|
uv run pytest tests
|
||||||
|
|
||||||
- name: Check kernel conversion
|
- name: Check kernel conversion
|
||||||
|
|||||||
@ -62,6 +62,7 @@ the Hub.
|
|||||||
- [Using layers](docs/source/layers.md)
|
- [Using layers](docs/source/layers.md)
|
||||||
- [Locking kernel/layer versions](docs/source/locking.md)
|
- [Locking kernel/layer versions](docs/source/locking.md)
|
||||||
- [Environment variables](docs/source/env.md)
|
- [Environment variables](docs/source/env.md)
|
||||||
|
- [Using kernels in a Docker container](docs/source/docker.md)
|
||||||
- [Kernel requirements](docs/source/kernel-requirements.md)
|
- [Kernel requirements](docs/source/kernel-requirements.md)
|
||||||
- [Frequently Asked Questions](docs/source/faq.md)
|
- [Frequently Asked Questions](docs/source/faq.md)
|
||||||
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
|
|||||||
@ -21,8 +21,6 @@
|
|||||||
title: Kernels
|
title: Kernels
|
||||||
- local: api/layers
|
- local: api/layers
|
||||||
title: Layers
|
title: Layers
|
||||||
- local: cli
|
|
||||||
title: Kernels CLI
|
|
||||||
title: API Reference
|
title: API Reference
|
||||||
- sections:
|
- sections:
|
||||||
- local: kernel-requirements
|
- local: kernel-requirements
|
||||||
|
|||||||
@ -21,22 +21,6 @@ activation.gelu_fast(y, x)
|
|||||||
print(y)
|
print(y)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using version bounds
|
|
||||||
|
|
||||||
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
|
|
||||||
You can specify which version to download using Python version specifiers:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from kernels import get_kernel
|
|
||||||
|
|
||||||
activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
|
|
||||||
```
|
|
||||||
|
|
||||||
This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
|
|
||||||
is strongly recommended to specify a version bound, since a kernel author
|
|
||||||
might push incompatible changes to the `main` branch.
|
|
||||||
|
|
||||||
## Checking Kernel Availability
|
## Checking Kernel Availability
|
||||||
|
|
||||||
You can check if a specific kernel is available for your environment:
|
You can check if a specific kernel is available for your environment:
|
||||||
|
|||||||
@ -1,41 +0,0 @@
|
|||||||
# Kernels CLI Reference
|
|
||||||
|
|
||||||
## Main Functions
|
|
||||||
|
|
||||||
### kernels to-wheel
|
|
||||||
|
|
||||||
We strongly recommend downloading kernels from the Hub using the `kernels`
|
|
||||||
package, since this comes with large [benefits](index.md) over using Python
|
|
||||||
wheels. That said, some projects may require deployment of kernels as
|
|
||||||
wheels. The `kernels` utility provides a simple solution to this. You can
|
|
||||||
convert any Hub kernel into a set of wheels with the `to-wheel` command:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ kernels to-wheel drbh/img2grey 1.1.2
|
|
||||||
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
|
|
||||||
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
|
||||||
```
|
|
||||||
|
|
||||||
### kernels upload
|
|
||||||
|
|
||||||
Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
|
|
||||||
your kernel builds to the Hub.
|
|
||||||
|
|
||||||
**Notes**:
|
|
||||||
|
|
||||||
- This will take care of creating a repository on the Hub with the `repo_id` provided.
|
|
||||||
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
|
|
||||||
being uploaded, it will attempt to delete the files existing under it.
|
|
||||||
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.
|
|
||||||
|
|
||||||
@ -34,8 +34,6 @@ Kernels are versioned on the Hub using Git tags. Version tags must be of
|
|||||||
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
||||||
to resolve the version constraints.
|
to resolve the version constraints.
|
||||||
|
|
||||||
We recommend using [semver](https://semver.org/) to version kernels.
|
|
||||||
|
|
||||||
## Native Python module
|
## Native Python module
|
||||||
|
|
||||||
Kernels will typically contain a native Python module with precompiled
|
Kernels will typically contain a native Python module with precompiled
|
||||||
@ -52,12 +50,13 @@ have dynamic library dependencies outside:
|
|||||||
for compatibility with Python 3.9 and later.
|
for compatibility with Python 3.9 and later.
|
||||||
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
||||||
This means that the extension **must not** use symbols versions higher than:
|
This means that the extension **must not** use symbols versions higher than:
|
||||||
|
|
||||||
- GLIBC 2.28
|
- GLIBC 2.28
|
||||||
- GLIBCXX 3.4.24
|
- GLIBCXX 3.4.24
|
||||||
- CXXABI 1.3.11
|
- CXXABI 1.3.11
|
||||||
- GCC 7.0.0
|
- GCC 7.0.0
|
||||||
|
|
||||||
These requirements can be checked with the ABI checker (see below).
|
These requirement can be checked with the ABI checker (see below).
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ the Hub can replace the `forward` method of an existing layer for a certain
|
|||||||
device type. This makes it possible to provide more performant kernels for
|
device type. This makes it possible to provide more performant kernels for
|
||||||
existing layers.
|
existing layers.
|
||||||
|
|
||||||
See [Kernel requirements](kernel-requirements.md) for more information on the
|
See [Kernel requirements](kernel-requirements.md) for more information the
|
||||||
requirements of Hub layers.
|
requirements of Hub layers.
|
||||||
|
|
||||||
## Making a layer extensible with kernels from the hub
|
## Making a layer extensible with kernels from the hub
|
||||||
@ -111,7 +111,7 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=
|
|||||||
|
|
||||||
This can be useful if you want to guarantee that Hub kernels are used.
|
This can be useful if you want to guarantee that Hub kernels are used.
|
||||||
|
|
||||||
### Inspecting which kernels are used
|
### Inspecting kernels which kernels are used
|
||||||
|
|
||||||
The kernels that are used are logged at the `INFO` level by `kernelize`.
|
The kernels that are used are logged at the `INFO` level by `kernelize`.
|
||||||
See the [Python logging](https://docs.python.org/3/library/logging.html)
|
See the [Python logging](https://docs.python.org/3/library/logging.html)
|
||||||
@ -157,33 +157,6 @@ with use_kernel_mapping(kernel_layer_mapping):
|
|||||||
This ensures that the mapping is not active anymore outside the
|
This ensures that the mapping is not active anymore outside the
|
||||||
`with`-scope.
|
`with`-scope.
|
||||||
|
|
||||||
### Using version bounds
|
|
||||||
|
|
||||||
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
|
|
||||||
You can specify which version of the kernel to download using Python version
|
|
||||||
specifiers:
|
|
||||||
|
|
||||||
```python
|
|
||||||
kernel_layer_mapping = {
|
|
||||||
"SiluAndMul": {
|
|
||||||
"cuda": LayerRepository(
|
|
||||||
repo_id="kernels-community/activation",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
version=">=0.0.4,<0.1.0",
|
|
||||||
),
|
|
||||||
"rocm": LayerRepository(
|
|
||||||
repo_id="kernels-community/activation",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
version=">=0.0.4,<0.1.0",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
This will get the layer from latest kernel tagged `v0.0.z` where `z` is at
|
|
||||||
least 4. It is strongly recommended to specify a version bound, since a
|
|
||||||
kernel author might push incompatible changes to the `main` branch.
|
|
||||||
|
|
||||||
### Registering kernels for specific modes
|
### Registering kernels for specific modes
|
||||||
|
|
||||||
You might want to register two different kernels for a particular layer,
|
You might want to register two different kernels for a particular layer,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "kernels"
|
name = "kernels"
|
||||||
version = "0.10.1.dev0"
|
version = "0.10.0.dev0"
|
||||||
description = "Download compute kernels"
|
description = "Download compute kernels"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||||
|
|||||||
@ -3,5 +3,3 @@ markers =
|
|||||||
cuda_only: marks tests that should only hosts with CUDA GPUs
|
cuda_only: marks tests that should only hosts with CUDA GPUs
|
||||||
rocm_only: marks tests that should only run on hosts with ROCm GPUs
|
rocm_only: marks tests that should only run on hosts with ROCm GPUs
|
||||||
darwin_only: marks tests that should only run on macOS
|
darwin_only: marks tests that should only run on macOS
|
||||||
xpu_only: marks tests that should only run on hosts with Intel XPUs
|
|
||||||
token: enable tests that require a write token
|
|
||||||
|
|||||||
@ -1,7 +1,3 @@
|
|||||||
import importlib.metadata
|
|
||||||
|
|
||||||
__version__ = importlib.metadata.version("kernels")
|
|
||||||
|
|
||||||
from kernels.layer import (
|
from kernels.layer import (
|
||||||
CUDAProperties,
|
CUDAProperties,
|
||||||
Device,
|
Device,
|
||||||
@ -25,7 +21,6 @@ from kernels.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"__version__",
|
|
||||||
"CUDAProperties",
|
"CUDAProperties",
|
||||||
"Device",
|
"Device",
|
||||||
"LayerRepository",
|
"LayerRepository",
|
||||||
|
|||||||
@ -17,10 +17,8 @@ from types import MethodType, ModuleType
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Dict,
|
Dict,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
Set,
|
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
@ -89,7 +87,7 @@ class Device:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
type (`str`):
|
type (`str`):
|
||||||
The device type (e.g., "cuda", "mps", "rocm", "xpu").
|
The device type (e.g., "cuda", "mps", "rocm").
|
||||||
properties ([`CUDAProperties`], *optional*):
|
properties ([`CUDAProperties`], *optional*):
|
||||||
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
|
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
|
||||||
|
|
||||||
@ -108,9 +106,6 @@ class Device:
|
|||||||
|
|
||||||
# MPS device for Apple Silicon
|
# MPS device for Apple Silicon
|
||||||
mps_device = Device(type="mps")
|
mps_device = Device(type="mps")
|
||||||
|
|
||||||
# XPU device (e.g., Intel(R) Data Center GPU Max 1550)
|
|
||||||
xpu_device = Device(type="xpu")
|
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -130,8 +125,6 @@ class Device:
|
|||||||
return _ROCMRepos()
|
return _ROCMRepos()
|
||||||
elif self.type == "mps":
|
elif self.type == "mps":
|
||||||
return _MPSRepos()
|
return _MPSRepos()
|
||||||
elif self.type == "xpu":
|
|
||||||
return _XPURepos()
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown device type: {self.type}")
|
raise ValueError(f"Unknown device type: {self.type}")
|
||||||
|
|
||||||
@ -318,7 +311,7 @@ class LayerRepository:
|
|||||||
return hash((self.layer_name, self._repo_id, self._revision, self._version))
|
return hash((self.layer_name, self._repo_id, self._revision, self._version))
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"
|
return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
|
||||||
|
|
||||||
|
|
||||||
class LocalLayerRepository:
|
class LocalLayerRepository:
|
||||||
@ -374,7 +367,7 @@ class LocalLayerRepository:
|
|||||||
return hash((self.layer_name, self._repo_path, self._package_name))
|
return hash((self.layer_name, self._repo_path, self._package_name))
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.layer_name}`"
|
return f"`{self._repo_path}` (package: {self._package_name}) for layer `{self.layer_name}`"
|
||||||
|
|
||||||
|
|
||||||
class LockedLayerRepository:
|
class LockedLayerRepository:
|
||||||
@ -429,7 +422,7 @@ class LockedLayerRepository:
|
|||||||
return hash((self.layer_name, self._repo_id))
|
return hash((self.layer_name, self._repo_id))
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"
|
return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
|
||||||
|
|
||||||
|
|
||||||
_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
|
_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
|
||||||
@ -454,26 +447,6 @@ class _DeviceRepos(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class _XPURepos(_DeviceRepos):
|
|
||||||
_repos: Dict[Mode, LayerRepositoryProtocol]
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self._repos = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def repos(
|
|
||||||
self,
|
|
||||||
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
|
|
||||||
return self._repos
|
|
||||||
|
|
||||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
|
||||||
if device.type != "xpu":
|
|
||||||
raise ValueError(f"Device type must be 'xpu', got {device.type}")
|
|
||||||
|
|
||||||
self._repos = repos
|
|
||||||
|
|
||||||
|
|
||||||
class _MPSRepos(_DeviceRepos):
|
class _MPSRepos(_DeviceRepos):
|
||||||
_repos: Dict[Mode, LayerRepositoryProtocol]
|
_repos: Dict[Mode, LayerRepositoryProtocol]
|
||||||
|
|
||||||
@ -558,7 +531,7 @@ class _ROCMRepos(_DeviceRepos):
|
|||||||
|
|
||||||
def _validate_device_type(device_type: str) -> None:
|
def _validate_device_type(device_type: str) -> None:
|
||||||
"""Validate that the device type is supported."""
|
"""Validate that the device type is supported."""
|
||||||
supported_devices = {"cuda", "rocm", "mps", "xpu"}
|
supported_devices = {"cuda", "rocm", "mps"}
|
||||||
if device_type not in supported_devices:
|
if device_type not in supported_devices:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
|
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
|
||||||
@ -816,7 +789,7 @@ def kernelize(
|
|||||||
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
|
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
|
||||||
`torch.compile`.
|
`torch.compile`.
|
||||||
device (`Union[str, torch.device]`, *optional*):
|
device (`Union[str, torch.device]`, *optional*):
|
||||||
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
|
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
|
||||||
The device type will be inferred from the model parameters when not provided.
|
The device type will be inferred from the model parameters when not provided.
|
||||||
use_fallback (`bool`, *optional*, defaults to `True`):
|
use_fallback (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to use the original forward method of modules when no compatible kernel could be found.
|
Whether to use the original forward method of modules when no compatible kernel could be found.
|
||||||
@ -870,14 +843,10 @@ def kernelize(
|
|||||||
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
|
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = _find_device(model)
|
device_type = _find_device(model)
|
||||||
device_type = _find_device_type(model)
|
|
||||||
elif isinstance(device, str):
|
elif isinstance(device, str):
|
||||||
_validate_device_type(device)
|
_validate_device_type(device)
|
||||||
import torch
|
|
||||||
|
|
||||||
device_type = Device(type=device)
|
device_type = Device(type=device)
|
||||||
device = torch.device(device)
|
|
||||||
else:
|
else:
|
||||||
device_type = Device(device.type)
|
device_type = Device(device.type)
|
||||||
|
|
||||||
@ -890,7 +859,7 @@ def kernelize(
|
|||||||
layer_name = module_class.kernel_layer_name
|
layer_name = module_class.kernel_layer_name
|
||||||
|
|
||||||
if _DISABLE_KERNEL_MAPPING:
|
if _DISABLE_KERNEL_MAPPING:
|
||||||
_replace_forward(device, module, module_class)
|
_replace_forward(module, module_class)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
|
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
|
||||||
@ -904,7 +873,7 @@ def kernelize(
|
|||||||
)
|
)
|
||||||
if not use_fallback:
|
if not use_fallback:
|
||||||
raise ValueError(f"No layer mapping for `{layer_name}`")
|
raise ValueError(f"No layer mapping for `{layer_name}`")
|
||||||
_replace_forward(device, module, module_class)
|
_replace_forward(module, module_class)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get kernel options for the device
|
# Get kernel options for the device
|
||||||
@ -915,7 +884,7 @@ def kernelize(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
|
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
|
||||||
)
|
)
|
||||||
_replace_forward(device, module, module_class)
|
_replace_forward(module, module_class)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
repos = property_repos.repos
|
repos = property_repos.repos
|
||||||
@ -925,7 +894,7 @@ def kernelize(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
|
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
|
||||||
)
|
)
|
||||||
_replace_forward(device, module, module_class)
|
_replace_forward(module, module_class)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
repo_with_mode = _select_repository(
|
repo_with_mode = _select_repository(
|
||||||
@ -938,7 +907,7 @@ def kernelize(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No repository for `{layer_name}` for configuration mode={mode}"
|
f"No repository for `{layer_name}` for configuration mode={mode}"
|
||||||
)
|
)
|
||||||
_replace_forward(device, module, module_class)
|
_replace_forward(module, module_class)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
repo, repo_mode = repo_with_mode
|
repo, repo_mode = repo_with_mode
|
||||||
@ -957,7 +926,6 @@ def kernelize(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_conditionally_replace_forward(
|
_conditionally_replace_forward(
|
||||||
device=device,
|
|
||||||
module=module,
|
module=module,
|
||||||
layer=layer,
|
layer=layer,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
@ -1027,7 +995,7 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
|
|||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
|
def _validate_layer(*, check_cls, cls):
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# The layer must have at least have the following properties: (1) it
|
# The layer must have at least have the following properties: (1) it
|
||||||
@ -1036,48 +1004,34 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
|
|||||||
# methods.
|
# methods.
|
||||||
|
|
||||||
if not issubclass(cls, nn.Module):
|
if not issubclass(cls, nn.Module):
|
||||||
raise TypeError(f"Layer `{cls.__name__}` is not a Torch layer.")
|
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
|
||||||
|
|
||||||
# We verify statelessness by checking that the does not have its own
|
# We verify statelessness by checking that the does not have its own
|
||||||
# constructor (since the constructor could add member variables)...
|
# constructor (since the constructor could add member variables)...
|
||||||
if cls.__init__ is not nn.Module.__init__:
|
if cls.__init__ is not nn.Module.__init__:
|
||||||
raise TypeError(f"{repo} must not override nn.Module constructor.")
|
raise TypeError("Layer must not override nn.Module constructor.")
|
||||||
|
|
||||||
# ... or predefined member variables.
|
# ... or predefined member variables.
|
||||||
unique_members = _unique_layer_members(cls)
|
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
||||||
|
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
||||||
|
difference = cls_members - torch_module_members
|
||||||
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
|
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
|
||||||
if not unique_members <= {
|
if not difference <= {"can_torch_compile", "has_backward"}:
|
||||||
"can_torch_compile",
|
raise TypeError("Layer must not contain additional members.")
|
||||||
"create_state",
|
|
||||||
"has_backward",
|
|
||||||
"forward_with_state",
|
|
||||||
}:
|
|
||||||
raise TypeError(
|
|
||||||
f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check whether the forward signatures are similar.
|
# Check whether the forward signatures are similar.
|
||||||
|
params = inspect.signature(cls.forward).parameters
|
||||||
ref_params = inspect.signature(check_cls.forward).parameters
|
ref_params = inspect.signature(check_cls.forward).parameters
|
||||||
|
|
||||||
params: Mapping[str, inspect.Parameter]
|
|
||||||
if _is_stateful_layer(cls):
|
|
||||||
params = inspect.signature(cls.forward_with_state).parameters
|
|
||||||
# Get rid of the mappingproxy.
|
|
||||||
params = params.copy()
|
|
||||||
# Remove the state to be able to compare with forward.
|
|
||||||
del params["state"]
|
|
||||||
else:
|
|
||||||
params = inspect.signature(cls.forward).parameters
|
|
||||||
|
|
||||||
if len(params) != len(ref_params):
|
if len(params) != len(ref_params):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
|
"Forward signature does not match: different number of arguments."
|
||||||
)
|
)
|
||||||
|
|
||||||
for param, ref_param in zip(params.values(), ref_params.values()):
|
for param, ref_param in zip(params.values(), ref_params.values()):
|
||||||
if param.kind != ref_param.kind:
|
if param.kind != ref_param.kind:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1093,7 +1047,7 @@ def _is_rocm_platform():
|
|||||||
return torch.version.hip is not None
|
return torch.version.hip is not None
|
||||||
|
|
||||||
|
|
||||||
def _find_device(model: "nn.Module") -> torch.device:
|
def _find_device(model: "nn.Module") -> Device:
|
||||||
try:
|
try:
|
||||||
param = next(model.parameters())
|
param = next(model.parameters())
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
@ -1101,13 +1055,7 @@ def _find_device(model: "nn.Module") -> torch.device:
|
|||||||
"Cannot determine model device, provide as `device` argument to `kernelize`."
|
"Cannot determine model device, provide as `device` argument to `kernelize`."
|
||||||
)
|
)
|
||||||
|
|
||||||
return param.device
|
dev_type = param.device.type
|
||||||
|
|
||||||
|
|
||||||
def _find_device_type(model: "nn.Module") -> Device:
|
|
||||||
device = _find_device(model)
|
|
||||||
|
|
||||||
dev_type = device.type
|
|
||||||
if dev_type == "cuda":
|
if dev_type == "cuda":
|
||||||
# Refine based on actual platform
|
# Refine based on actual platform
|
||||||
if _is_rocm_platform():
|
if _is_rocm_platform():
|
||||||
@ -1128,7 +1076,6 @@ def _find_capability() -> int:
|
|||||||
|
|
||||||
def _conditionally_replace_forward(
|
def _conditionally_replace_forward(
|
||||||
*,
|
*,
|
||||||
device: "torch.device",
|
|
||||||
module: "nn.Module",
|
module: "nn.Module",
|
||||||
layer: Type["nn.Module"],
|
layer: Type["nn.Module"],
|
||||||
mode: Mode,
|
mode: Mode,
|
||||||
@ -1154,25 +1101,15 @@ def _conditionally_replace_forward(
|
|||||||
logging.info("Layer does not support torch.compile, using fallback")
|
logging.info("Layer does not support torch.compile, using fallback")
|
||||||
if needs_fallback_for_backward:
|
if needs_fallback_for_backward:
|
||||||
logging.info("Layer does not support backward, using fallback")
|
logging.info("Layer does not support backward, using fallback")
|
||||||
_replace_forward(device, module, module_class)
|
_replace_forward(module, module_class)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Available kernel does not support mode: {mode}")
|
raise ValueError(f"Available kernel does not support mode: {mode}")
|
||||||
else:
|
else:
|
||||||
_replace_forward(device, module, layer)
|
_replace_forward(module, layer)
|
||||||
|
|
||||||
|
|
||||||
def _replace_forward(
|
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
|
||||||
device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
|
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
|
||||||
):
|
|
||||||
if _is_stateful_layer(layer):
|
|
||||||
state = layer.create_state(device, module) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
return layer.forward_with_state(self, state, *args, **kwargs)
|
|
||||||
|
|
||||||
module.forward = MethodType(forward, module)
|
|
||||||
else:
|
|
||||||
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_layer_has_mode(
|
def _validate_layer_has_mode(
|
||||||
@ -1211,25 +1148,7 @@ def _get_layer_memoize(
|
|||||||
return layer
|
return layer
|
||||||
|
|
||||||
layer = _get_kernel_layer(repo)
|
layer = _get_kernel_layer(repo)
|
||||||
_validate_layer(check_cls=module_class, cls=layer, repo=repo)
|
_validate_layer(check_cls=module_class, cls=layer)
|
||||||
_CACHED_LAYER[repo] = layer
|
_CACHED_LAYER[repo] = layer
|
||||||
|
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
|
||||||
cls_members = {name for name, _ in inspect.getmembers(layer)}
|
|
||||||
return cls_members - torch_module_members
|
|
||||||
|
|
||||||
|
|
||||||
def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
|
|
||||||
unique = _unique_layer_members(layer)
|
|
||||||
is_stateful = "forward_with_state" in unique
|
|
||||||
if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
|
|
||||||
raise TypeError(
|
|
||||||
f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
|
|
||||||
)
|
|
||||||
return is_stateful
|
|
||||||
|
|||||||
@ -46,9 +46,8 @@ def build_variant() -> str:
|
|||||||
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
compute_framework = "metal"
|
compute_framework = "metal"
|
||||||
elif torch.version.xpu is not None:
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
version = torch.version.xpu
|
compute_framework = "xpu"
|
||||||
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
|
|
||||||
else:
|
else:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
|
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
|
||||||
|
|||||||
@ -13,19 +13,6 @@ has_rocm = (
|
|||||||
and torch.version.hip is not None
|
and torch.version.hip is not None
|
||||||
and torch.cuda.device_count() > 0
|
and torch.cuda.device_count() > 0
|
||||||
)
|
)
|
||||||
has_xpu = (
|
|
||||||
hasattr(torch.version, "xpu")
|
|
||||||
and torch.version.xpu is not None
|
|
||||||
and torch.xpu.device_count() > 0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
|
||||||
parser.addoption(
|
|
||||||
"--token",
|
|
||||||
action="store_true",
|
|
||||||
help="run tests that require a token with write permissions",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_setup(item):
|
def pytest_runtest_setup(item):
|
||||||
@ -35,7 +22,3 @@ def pytest_runtest_setup(item):
|
|||||||
pytest.skip("skipping ROCm-only test on host without ROCm")
|
pytest.skip("skipping ROCm-only test on host without ROCm")
|
||||||
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
|
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
|
||||||
pytest.skip("skipping macOS-only test on non-macOS platform")
|
pytest.skip("skipping macOS-only test on non-macOS platform")
|
||||||
if "xpu_only" in item.keywords and not has_xpu:
|
|
||||||
pytest.skip("skipping XPU-only test on host without XPU")
|
|
||||||
if "token" in item.keywords and not item.config.getoption("--token"):
|
|
||||||
pytest.skip("need --token option to run this test")
|
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from dataclasses import dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
|
||||||
from huggingface_hub import model_info
|
from huggingface_hub import model_info
|
||||||
|
|
||||||
from kernels.cli import upload_kernels
|
from kernels.cli import upload_kernels
|
||||||
@ -67,7 +66,6 @@ def get_filenames_from_a_repo(repo_id: str) -> List[str]:
|
|||||||
logging.error(f"Error connecting to the Hub: {e}.")
|
logging.error(f"Error connecting to the Hub: {e}.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.token
|
|
||||||
def test_kernel_upload_deletes_as_expected():
|
def test_kernel_upload_deletes_as_expected():
|
||||||
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
||||||
filename_to_change = get_filename_to_change(repo_filenames)
|
filename_to_change = get_filename_to_change(repo_filenames)
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.testing import assert_close
|
|
||||||
|
|
||||||
from kernels import (
|
from kernels import (
|
||||||
CUDAProperties,
|
CUDAProperties,
|
||||||
@ -47,37 +46,11 @@ kernel_layer_mapping = {
|
|||||||
layer_name="SiluAndMul",
|
layer_name="SiluAndMul",
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
"LigerRMSNorm": {
|
|
||||||
"xpu": LayerRepository(
|
|
||||||
repo_id="kernels-community/liger_kernels",
|
|
||||||
layer_name="LigerRMSNorm", # Triton
|
|
||||||
)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
register_kernel_mapping(kernel_layer_mapping)
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
|
|
||||||
super().__init__()
|
|
||||||
# Used to check that we called hub kernel.
|
|
||||||
self.n_calls = 0
|
|
||||||
self.weight = nn.Parameter(weight)
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
self.n_calls += 1
|
|
||||||
var = x.pow(2).mean(-1, keepdim=True)
|
|
||||||
x_norm = x * torch.rsqrt(var + self.variance_epsilon)
|
|
||||||
return x_norm * self.weight
|
|
||||||
|
|
||||||
|
|
||||||
@use_kernel_forward_from_hub("LigerRMSNorm")
|
|
||||||
class RMSNormWithKernel(RMSNorm):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SiluAndMul(nn.Module):
|
class SiluAndMul(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -117,16 +90,6 @@ class TorchLinearWithCounter(nn.Linear):
|
|||||||
return super().forward(input)
|
return super().forward(input)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def device():
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return "cuda"
|
|
||||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
return "xpu"
|
|
||||||
|
|
||||||
pytest.skip("No CUDA or XPU")
|
|
||||||
|
|
||||||
|
|
||||||
def test_arg_kinds():
|
def test_arg_kinds():
|
||||||
@use_kernel_forward_from_hub("ArgKind")
|
@use_kernel_forward_from_hub("ArgKind")
|
||||||
class ArgKind(nn.Module):
|
class ArgKind(nn.Module):
|
||||||
@ -184,31 +147,6 @@ def test_hub_forward_rocm():
|
|||||||
assert silu_and_mul_with_kernel.n_calls in [0, 1]
|
assert silu_and_mul_with_kernel.n_calls in [0, 1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xpu_only
|
|
||||||
def test_hub_forward_xpu():
|
|
||||||
torch.manual_seed(0)
|
|
||||||
|
|
||||||
hidden_size = 1024
|
|
||||||
weight = torch.ones(hidden_size, device="xpu")
|
|
||||||
rms_norm = RMSNorm(weight).to("xpu")
|
|
||||||
X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
|
|
||||||
Y = rms_norm(X)
|
|
||||||
|
|
||||||
rms_norm_with_kernel = kernelize(
|
|
||||||
RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
|
|
||||||
)
|
|
||||||
Y_kernel = rms_norm_with_kernel(X)
|
|
||||||
|
|
||||||
torch.testing.assert_close(Y_kernel, Y)
|
|
||||||
|
|
||||||
assert rms_norm.n_calls == 1
|
|
||||||
assert rms_norm_with_kernel.n_calls == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
|
|
||||||
reason="Skip on xpu devices",
|
|
||||||
)
|
|
||||||
def test_rocm_kernel_mapping():
|
def test_rocm_kernel_mapping():
|
||||||
"""Test that ROCm shorthand device mapping works correctly."""
|
"""Test that ROCm shorthand device mapping works correctly."""
|
||||||
kernel_layer_mapping = {
|
kernel_layer_mapping = {
|
||||||
@ -296,16 +234,16 @@ def test_layer_fallback_works():
|
|||||||
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
|
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
|
||||||
|
|
||||||
def test_local_layer_repo(device):
|
def test_local_layer_repo():
|
||||||
# Fetch a kernel to the local cache.
|
# Fetch a kernel to the local cache.
|
||||||
package_name, path = install_kernel("kernels-test/backward-marker-test", "main")
|
package_name, path = install_kernel("kernels-test/backward-marker-test", "main")
|
||||||
|
|
||||||
linear = TorchLinearWithCounter(32, 32).to(device)
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Linear": {
|
"Linear": {
|
||||||
device: LocalLayerRepository(
|
"cuda": LocalLayerRepository(
|
||||||
# install_kernel will give the fully-resolved path.
|
# install_kernel will give the fully-resolved path.
|
||||||
repo_path=path.parent.parent,
|
repo_path=path.parent.parent,
|
||||||
package_name=package_name,
|
package_name=package_name,
|
||||||
@ -317,52 +255,11 @@ def test_local_layer_repo(device):
|
|||||||
):
|
):
|
||||||
kernelize(linear, mode=Mode.INFERENCE)
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
|
||||||
X = torch.randn(10, 32, device=device)
|
X = torch.randn(10, 32, device="cuda")
|
||||||
linear(X)
|
linear(X)
|
||||||
assert linear.n_calls == 0
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
|
||||||
def test_stateful_layer(device):
|
|
||||||
@use_kernel_forward_from_hub("ReluWithHiddenSize")
|
|
||||||
class ReluWithHiddenSize(nn.Module):
|
|
||||||
hidden_size: int
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return F.relu(x)
|
|
||||||
|
|
||||||
model = ReluWithHiddenSize(hidden_size=64).to(device)
|
|
||||||
x = torch.randn((32, 64), device=device)
|
|
||||||
y_ref = model(x)
|
|
||||||
|
|
||||||
with use_kernel_mapping(
|
|
||||||
{
|
|
||||||
"ReluWithHiddenSize": {
|
|
||||||
"cuda": LayerRepository(
|
|
||||||
repo_id="kernels-test/state-test",
|
|
||||||
layer_name="StatefulReLU",
|
|
||||||
),
|
|
||||||
"xpu": LayerRepository(
|
|
||||||
repo_id="kernels-test/state-test",
|
|
||||||
layer_name="StatefulReLU",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
inherit_mapping=False,
|
|
||||||
):
|
|
||||||
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
|
|
||||||
|
|
||||||
y = model(x)
|
|
||||||
assert_close(y, y_ref)
|
|
||||||
|
|
||||||
model = torch.compile(model, fullgraph=True)
|
|
||||||
y = model(x)
|
|
||||||
assert_close(y, y_ref)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.cuda_only
|
@pytest.mark.cuda_only
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
||||||
@pytest.mark.parametrize("device", ["cuda"])
|
@pytest.mark.parametrize("device", ["cuda"])
|
||||||
@ -426,7 +323,6 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
"LigerRMSNorm",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
extra_mapping1 = {
|
extra_mapping1 = {
|
||||||
@ -444,7 +340,6 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
"LigerRMSNorm",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -463,7 +358,6 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
"LigerRMSNorm",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
@ -477,7 +371,6 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
"LigerRMSNorm",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
@ -500,7 +393,6 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
"LigerRMSNorm",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
@ -512,7 +404,6 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
"LigerRMSNorm",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -522,43 +413,26 @@ def test_validate_kernel_layer():
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.foo = 42
|
self.foo = 42
|
||||||
|
|
||||||
def stub_repo(layer):
|
with pytest.raises(TypeError, match="not override"):
|
||||||
return LayerRepository(
|
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
|
||||||
repo_id="kernels-test/nonexisting", layer_name=layer.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
TypeError,
|
|
||||||
match="`kernels-test/nonexisting`.*layer `BadLayer` must not override",
|
|
||||||
):
|
|
||||||
_validate_layer(cls=BadLayer, check_cls=SiluAndMul, repo=stub_repo(BadLayer))
|
|
||||||
|
|
||||||
class BadLayer2(nn.Module):
|
class BadLayer2(nn.Module):
|
||||||
foo: int = 42
|
foo: int = 42
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(TypeError, match="not contain additional members"):
|
||||||
TypeError,
|
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
|
||||||
match="`kernels-test/nonexisting`.*layer `BadLayer2` must not contain.*SiluAndMul",
|
|
||||||
):
|
|
||||||
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul, repo=stub_repo(BadLayer2))
|
|
||||||
|
|
||||||
class BadLayer3(nn.Module):
|
class BadLayer3(nn.Module):
|
||||||
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
|
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(TypeError, match="different number of arguments"):
|
||||||
TypeError,
|
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
|
||||||
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer3` does not match `SiluAndMul`: different number of arguments",
|
|
||||||
):
|
|
||||||
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul, repo=stub_repo(BadLayer3))
|
|
||||||
|
|
||||||
class BadLayer4(nn.Module):
|
class BadLayer4(nn.Module):
|
||||||
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
|
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(TypeError, match="different kind of arguments"):
|
||||||
TypeError,
|
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
||||||
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer4` does not match `SiluAndMul`: different kind of arguments",
|
|
||||||
):
|
|
||||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul, repo=stub_repo(BadLayer4))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.cuda_only
|
@pytest.mark.cuda_only
|
||||||
@ -1049,7 +923,7 @@ def test_kernel_modes_cross_fallback():
|
|||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
|
||||||
def test_layer_versions(device):
|
def test_layer_versions():
|
||||||
@use_kernel_forward_from_hub("Version")
|
@use_kernel_forward_from_hub("Version")
|
||||||
class Version(nn.Module):
|
class Version(nn.Module):
|
||||||
def forward(self) -> str:
|
def forward(self) -> str:
|
||||||
@ -1060,20 +934,20 @@ def test_layer_versions(device):
|
|||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type=device): LayerRepository(
|
Device(type="cuda"): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
assert version() == "0.2.0"
|
assert version() == "0.2.0"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type=device): LayerRepository(
|
Device(type="cuda"): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version="<1.0.0",
|
version="<1.0.0",
|
||||||
@ -1081,13 +955,13 @@ def test_layer_versions(device):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
assert version() == "0.2.0"
|
assert version() == "0.2.0"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type=device): LayerRepository(
|
Device(type="cuda"): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version="<0.2.0",
|
version="<0.2.0",
|
||||||
@ -1095,13 +969,13 @@ def test_layer_versions(device):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
assert version() == "0.1.1"
|
assert version() == "0.1.1"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type=device): LayerRepository(
|
Device(type="cuda"): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version=">0.1.0,<0.2.0",
|
version=">0.1.0,<0.2.0",
|
||||||
@ -1109,13 +983,13 @@ def test_layer_versions(device):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
assert version() == "0.1.1"
|
assert version() == "0.1.1"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type=device): LayerRepository(
|
Device(type="cuda"): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version=">0.2.0",
|
version=">0.2.0",
|
||||||
@ -1124,13 +998,13 @@ def test_layer_versions(device):
|
|||||||
}
|
}
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
||||||
kernelize(version, device=device, mode=Mode.INFERENCE)
|
kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
||||||
use_kernel_mapping(
|
use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type=device): LayerRepository(
|
Device(type="cuda"): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
revision="v0.1.0",
|
revision="v0.1.0",
|
||||||
|
|||||||
Reference in New Issue
Block a user