Mirror of https://github.com/huggingface/kernels.git, synced 2025-10-20 21:10:02 +08:00

Compare commits: release-0. ... v0.10.1 (9 commits)
Commits:

- f03cde8151
- 07e5e8481a
- 88f55d4728
- e801ebf332
- 0ae07f05fc
- 7611021100
- 767e7ccf13
- 1caa4c1393
- da701bf58a
@@ -84,12 +84,6 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
 model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
 ```
 
-When the `mode` argument is not specified,
-`Mode.TRAINING | Mode.TORCH_COMPILE` is used as the default. This mode
-aligns most closely with pure PyTorch layers which also support training
-and `torch.compile`. However, to select the most performant kernels, it
-is often good to make the mode as specific as possible.
-
 ### Kernel device
 
 Kernels can be registered per device type. For instance, separate `cuda` and
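Note: with the default gone (see the `kernelize` signature change further down), call sites now spell the mode out. A minimal sketch of an explicit call, assuming a CUDA machine; the stand-in model is illustrative:

```python
import torch.nn as nn
from kernels import Mode, kernelize

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU())  # stand-in; unmapped layers keep their original forward

# Pick the most specific mode for how the model will actually run.
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, device="cuda")
```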
@@ -157,7 +151,7 @@ used with the `use_kernel_mapping` context manager:
 ```python
 with use_kernel_mapping(kernel_layer_mapping):
     # Use the layer for which the mapping is applied.
-    model = kernelize(model)
+    model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
 ```
 
 This ensures that the mapping is not active anymore outside the
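For context on the changed line: the mapping applies only to `kernelize` calls made inside the `with` block. A sketch, reusing `kernel_layer_mapping` and `model` from the surrounding docs:

```python
from kernels import Mode, kernelize, use_kernel_mapping

with use_kernel_mapping(kernel_layer_mapping):
    # Resolves layers through the temporary mapping.
    model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)

# Out here the temporary mapping is gone; only globally registered
# mappings are considered.
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```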
@@ -285,12 +279,11 @@ a kernel to a range of ROCm capabilities.
 The `LocalLayerRepository` class is provided to load a repository from
 a local directory. For example:
 
-```
+```python
 with use_kernel_mapping(
     {
         "SiluAndMul": {
             "cuda": LocalLayerRepository(
-                # install_kernel will give the fully-resolved path.
                 repo_path="/home/daniel/kernels/activation",
                 package_name="activation",
                 layer_name="SiluAndMul",
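The hunk is cut off mid-example; for orientation, a complete version of that mapping would look roughly like the sketch below. The `repo_path` is illustrative and `model` is assumed to exist:

```python
from kernels import LocalLayerRepository, Mode, kernelize, use_kernel_mapping

with use_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": LocalLayerRepository(
                repo_path="/home/daniel/kernels/activation",  # local checkout of the kernel repo
                package_name="activation",
                layer_name="SiluAndMul",
            )
        }
    }
):
    model = kernelize(model, mode=Mode.INFERENCE, device="cuda")
```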
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "kernels"
-version = "0.9.0.dev0"
+version = "0.10.1"
 description = "Download compute kernels"
 authors = [
     { name = "OlivierDehaene", email = "olivier@huggingface.co" },
kernels/__init__.py

@@ -1,3 +1,7 @@
+import importlib.metadata
+
+__version__ = importlib.metadata.version("kernels")
+
 from kernels.layer import (
     CUDAProperties,
     Device,
@@ -21,6 +25,7 @@ from kernels.utils import (
 )
 
 __all__ = [
+    "__version__",
     "CUDAProperties",
     "Device",
     "LayerRepository",
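A quick way to sanity-check the new export after upgrading; a trivial sketch:

```python
import kernels

# __version__ is resolved from package metadata at import time, so it
# always reflects the installed distribution.
print(kernels.__version__)  # e.g. "0.10.1"
```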
kernels/layer.py

@@ -87,7 +87,7 @@ class Device:
 
     Args:
         type (`str`):
-            The device type (e.g., "cuda", "mps", "cpu").
+            The device type (e.g., "cuda", "mps", "rocm").
         properties ([`CUDAProperties`], *optional*):
             Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
@@ -531,7 +531,7 @@ class _ROCMRepos(_DeviceRepos):
 
 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "rocm", "mps", "cpu"}
+    supported_devices = {"cuda", "rocm", "mps"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
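The effect of the tightened set, shown against the private helper from this hunk (not public API, used here only to illustrate the new behavior):

```python
from kernels.layer import _validate_device_type

_validate_device_type("rocm")  # passes: rocm remains supported
_validate_device_type("cpu")   # raises ValueError: Unsupported device type 'cpu'.
                               # Supported device types are: cuda, mps, rocm
```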
@@ -578,7 +578,7 @@ def use_kernel_mapping(
 
     from kernels import use_kernel_forward_from_hub
     from kernels import use_kernel_mapping, LayerRepository, Device
-    from kernels import kernelize
+    from kernels import Mode, kernelize
 
     # Define a mapping
     mapping = {
@@ -601,7 +601,7 @@ def use_kernel_mapping(
     # Use the mapping for the duration of the context.
     with use_kernel_mapping(mapping):
         # kernelize uses the temporary mapping
-        model = kernelize(model, device="cuda")
+        model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")
 
     # Outside the context, original mappings are restored
     ```
@@ -772,7 +772,7 @@ def _select_repository(
 def kernelize(
     model: "nn.Module",
     *,
-    mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
+    mode: Mode,
     device: Optional[Union[str, "torch.device"]] = None,
     use_fallback: bool = True,
 ):
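Since `mode` is now a required keyword-only parameter, omitting it fails loudly instead of silently defaulting. A hedged illustration, assuming `model` is defined:

```python
from kernels import Mode, kernelize

# kernelize(model)  # now raises: TypeError: kernelize() missing 1 required
#                   # keyword-only argument: 'mode'
model = kernelize(model, mode=Mode.INFERENCE, device="cuda")  # explicit mode required
```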
@@ -785,11 +785,11 @@ def kernelize(
     Args:
         model (`nn.Module`):
             The PyTorch model to kernelize.
-        mode ([`Mode`], *optional*, defaults to `Mode.TRAINING | Mode.TORCH_COMPILE`):
-            The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
-            kernelizes the model for training with `torch.compile`.
+        mode ([`Mode`]): The mode that the kernel is going to be used in. For example,
+            `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
+            `torch.compile`.
         device (`Union[str, torch.device]`, *optional*):
-            The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
+            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
             The device type will be inferred from the model parameters when not provided.
         use_fallback (`bool`, *optional*, defaults to `True`):
             Whether to use the original forward method of modules when no compatible kernel could be found.
@@ -829,7 +829,7 @@ def kernelize(
     )
 
     # Kernelize for inference
-    kernelized_model = kernelize(model)
+    kernelized_model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
     ```
     """
@@ -954,7 +954,8 @@ def use_kernel_forward_from_hub(layer_name: str):
     import torch
     import torch.nn as nn
 
-    from kernels import use_kernel_forward_from_hub, kernelize
+    from kernels import use_kernel_forward_from_hub
+    from kernels import Mode, kernelize
 
     @use_kernel_forward_from_hub("MyCustomLayer")
     class MyCustomLayer(nn.Module):
@@ -969,7 +970,7 @@ def use_kernel_forward_from_hub(layer_name: str):
     model = MyCustomLayer(768)
 
     # The layer can now be kernelized:
-    # model = kernelize(model, device="cuda")
+    # model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")
     ```
     """
kernels/utils.py

@@ -46,8 +46,9 @@ def build_variant() -> str:
         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
     elif torch.backends.mps.is_available():
         compute_framework = "metal"
-    elif hasattr(torch, "xpu") and torch.xpu.is_available():
-        compute_framework = "xpu"
+    elif torch.version.xpu is not None:
+        version = torch.version.xpu
+        compute_framework = f"xpu{version[0:4]}{version[5:6]}"
     else:
         raise AssertionError(
             "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
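The new branch folds the XPU toolkit version into the variant tag. A hypothetical illustration of the slicing, assuming `torch.version.xpu` strings shaped like `"2025.1.3"` (the exact format is not shown in this diff):

```python
# Illustration only; not part of the library.
version = "2025.1.3"  # assumed shape of torch.version.xpu
tag = f"xpu{version[0:4]}{version[5:6]}"  # year digits plus minor digit, skipping the separator
print(tag)  # -> "xpu20251"
```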
@@ -248,8 +249,24 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
     Returns:
         `ModuleType`: The imported kernel module.
     """
-    package_name, package_path = _load_kernel_from_path(repo_path, package_name)
-    return import_from_path(package_name, package_path / package_name / "__init__.py")
+    variant = build_variant()
+    universal_variant = universal_build_variant()
+
+    # Presume we were given the top level path of the kernel repository.
+    for base_path in [repo_path, repo_path / "build"]:
+        # Prefer the universal variant if it exists.
+        for v in [universal_variant, variant]:
+            package_path = base_path / v / package_name / "__init__.py"
+            if package_path.exists():
+                return import_from_path(package_name, package_path)
+
+    # If we didn't find the package in the repo we may have an explicit
+    # package path.
+    package_path = repo_path / package_name / "__init__.py"
+    if package_path.exists():
+        return import_from_path(package_name, package_path)
+
+    raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
 
 
 def has_kernel(
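With the rewritten lookup, all three path shapes exercised by the new test below resolve to the same module. A sketch against a hypothetical local checkout, with `get_local_kernel` imported as in the tests:

```python
from pathlib import Path
from kernels import get_local_kernel

repo = Path("/home/daniel/kernels/activation")  # hypothetical local kernel repo

activation = get_local_kernel(repo, "activation")            # top-level repo path
activation = get_local_kernel(repo / "build", "activation")  # build directory
# ...or a fully-resolved variant directory such as
# repo / "build" / "torch28-cxx11-cu128-x86_64-linux"
```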
@@ -10,10 +10,16 @@ def kernel():
 
 
 @pytest.fixture
-def local_kernel():
+def local_kernel_path():
     package_name, path = install_kernel("kernels-community/activation", "main")
     # Path is the build variant path (build/torch-<...>), so the grandparent
     # is the kernel repository path.
-    return get_local_kernel(path.parent.parent, package_name)
+    return package_name, path
+
+
+@pytest.fixture
+def local_kernel(local_kernel_path):
+    package_name, path = local_kernel_path
+    return get_local_kernel(path.parent.parent, package_name)
@@ -66,6 +72,39 @@ def test_local_kernel(local_kernel, device):
     assert torch.allclose(y, expected)
 
 
+@pytest.mark.cuda_only
+def test_local_kernel_path_types(local_kernel_path, device):
+    package_name, path = local_kernel_path
+
+    # Top-level repo path
+    # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
+    kernel = get_local_kernel(path.parent.parent, package_name)
+    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
+    y = torch.empty_like(x)
+
+    kernel.gelu_fast(y, x)
+    expected = torch.tensor(
+        [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
+        device=device,
+        dtype=torch.float16,
+    )
+    assert torch.allclose(y, expected)
+
+    # Build directory path
+    # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
+    kernel = get_local_kernel(path.parent.parent / "build", package_name)
+    y = torch.empty_like(x)
+    kernel.gelu_fast(y, x)
+    assert torch.allclose(y, expected)
+
+    # Explicit package path
+    # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
+    kernel = get_local_kernel(path, package_name)
+    y = torch.empty_like(x)
+    kernel.gelu_fast(y, x)
+    assert torch.allclose(y, expected)
+
+
 @pytest.mark.darwin_only
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_relu_metal(metal_kernel, dtype):
@@ -110,24 +110,20 @@ def test_arg_kinds():
 
 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
-@pytest.mark.parametrize("device", ["cuda", "cpu"])
-def test_hub_forward(cls, device):
+def test_hub_forward(cls):
     torch.random.manual_seed(0)
 
     silu_and_mul = SiluAndMul()
-    X = torch.randn((32, 64), device=device)
+    X = torch.randn((32, 64), device="cuda")
     Y = silu_and_mul(X)
 
-    silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
+    silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
    Y_kernel = silu_and_mul_with_kernel(X)
 
     torch.testing.assert_close(Y_kernel, Y)
 
     assert silu_and_mul.n_calls == 1
-    if device == "cuda":
-        assert silu_and_mul_with_kernel.n_calls == 0
-    else:
-        assert silu_and_mul_with_kernel.n_calls == 1
+    assert silu_and_mul_with_kernel.n_calls == 0
 
 
 @pytest.mark.rocm_only
@@ -488,11 +484,6 @@ def test_kernel_modes():
         linear(X)
         assert linear.n_calls == 0
 
-        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
-        kernelize(linear)
-        linear(X)
-        assert linear.n_calls == 0
-
     # Case 2: register a kernel just for training. If no base kernel
     # layer is registered, we fall back to the original layer.
     with use_kernel_mapping(
@@ -522,12 +513,6 @@ def test_kernel_modes():
         # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
         assert linear.n_calls == 1
 
-        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
-        kernelize(linear)
-        linear(X)
-        # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
-        assert linear.n_calls == 2
-
     # Case 3: register a kernel just for training and one for fallback.
     with use_kernel_mapping(
         {
@@ -549,23 +534,17 @@ def test_kernel_modes():
         X = torch.randn(10, 32, device="cuda")
         linear(X)
         # Falls back to TRAINING.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 1
 
         kernelize(linear, mode=Mode.TRAINING)
         linear(X)
         # Falls back to the TRAINING kernel.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 1
 
         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
         linear(X)
         # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
-        assert linear.n_calls == 2
-
-        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
-        kernelize(linear)
-        linear(X)
-        # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 1
 
     # Case 4: register a kernel with two preferences.
     with use_kernel_mapping(
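For readers tracing the Case 3 and Case 4 assertions: they exercise mode-keyed mappings, where a repository can be registered per mode and a fallback entry catches modes without a more specific registration. A schematic sketch of such a mapping; the repo id is hypothetical and the exact dict shape is assumed from the surrounding docs:

```python
from kernels import LayerRepository, Mode, use_kernel_mapping

mapping = {
    "Linear": {
        "cuda": {
            # Used when kernelizing for TRAINING (or a mode that falls back to it).
            Mode.TRAINING: LayerRepository(
                repo_id="kernels-test/linear",  # hypothetical repository
                layer_name="Linear",
            ),
            # Used for any mode without a more specific registration.
            Mode.FALLBACK: LayerRepository(
                repo_id="kernels-test/linear",
                layer_name="Linear",
            ),
        }
    }
}

with use_kernel_mapping(mapping):
    ...
```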
@@ -585,22 +564,17 @@ def test_kernel_modes():
         X = torch.randn(10, 32, device="cuda")
         linear(X)
         # Falls back to the TRAINING | TORCH_COMPILE kernel.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 1
 
         kernelize(linear, mode=Mode.TRAINING)
         linear(X)
         # TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 1
 
         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
         linear(X)
         # Uses TRAINING | TORCH_COMPILE kernel.
-        assert linear.n_calls == 2
-
-        kernelize(linear)
-        linear(X)
-        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 1
 
 
 @pytest.mark.cuda_only