Mirror of https://github.com/huggingface/kernels.git
Test examples in docstrings using mktestdocs (#118)
Also adjust examples so that they are correct.
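For context, this is roughly the workflow the new test file at the bottom of this diff automates: mktestdocs' `check_docstring` pulls the fenced ```python blocks out of an object's docstring and executes them, so an example that no longer runs fails the suite. A minimal sketch, not part of the commit itself; it needs the dev dependencies, network access, and a CUDA host, since the updated examples load real kernels:

    import kernels
    from mktestdocs import check_docstring

    # Extracts the ```python example embedded in get_kernel's docstring and
    # runs it; any exception raised while executing the example fails the check.
    check_docstring(obj=kernels.get_kernel)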
flake.lock (generated, 19 lines changed)
@@ -58,32 +58,33 @@
       "nixpkgs": "nixpkgs"
     },
     "locked": {
-      "lastModified": 1750775451,
-      "narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
+      "lastModified": 1753693795,
+      "narHash": "sha256-bKT2Alrbo9wPMINaKVkgnnp/5dp6TpQe80OgpHOz2jI=",
       "owner": "huggingface",
       "repo": "hf-nix",
-      "rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
+      "rev": "168df9452b3e6c961fdce0115899a1f8b0947a73",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
+      "ref": "mktestdocs-0.2.5",
       "repo": "hf-nix",
       "type": "github"
     }
   },
   "nixpkgs": {
     "locked": {
-      "lastModified": 1747820358,
-      "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
-      "owner": "danieldk",
+      "lastModified": 1752785354,
+      "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
+      "owner": "nixos",
       "repo": "nixpkgs",
-      "rev": "d3c1681180717528068082103bf323147de6ab0b",
+      "rev": "d38025438a6ee456758dc03188ca6873a415463b",
       "type": "github"
     },
     "original": {
-      "owner": "danieldk",
+      "owner": "nixos",
-      "ref": "cudatoolkit-12.9-kernel-builder",
       "repo": "nixpkgs",
+      "rev": "d38025438a6ee456758dc03188ca6873a415463b",
       "type": "github"
     }
   },
@@ -1,6 +1,6 @@
 {
   inputs = {
-    hf-nix.url = "github:huggingface/hf-nix";
+    hf-nix.url = "github:huggingface/hf-nix/mktestdocs-0.2.5";
     nixpkgs.follows = "hf-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
   };
@@ -40,6 +40,7 @@
       ++ (with python3.pkgs; [
           docutils
           huggingface-hub
+          mktestdocs
           pytest
           pytest-benchmark
           pyyaml
@@ -24,11 +24,12 @@ build-backend = "setuptools.build_meta"

 [dependency-groups]
 dev = [
-    "mypy >= 1.15.0",
-    "pytest >=8",
+    "mktestdocs>=0.2.5",
+    "mypy>=1.15.0",
+    "pytest>=8",
     # Whatever version is compatible with pytest.
     "pytest-benchmark",
-    "torch >=2.5",
+    "torch>=2.5",
     "types-pyyaml"
 ]

@@ -1,4 +1,4 @@
 [pytest]
 markers =
+    cuda_only: marks tests that should run only on hosts with CUDA GPUs
     darwin_only: marks tests that should only run on macOS
-    linux_only: marks tests that should only run on Linux
@@ -2,6 +2,7 @@ from kernels.layer import (
     CUDAProperties,
     Device,
     LayerRepository,
+    LockedLayerRepository,
     Mode,
     kernelize,
     register_kernel_mapping,
@@ -22,6 +23,7 @@ __all__ = [
     "CUDAProperties",
     "Device",
     "LayerRepository",
+    "LockedLayerRepository",
     "Mode",
     "get_kernel",
     "get_local_kernel",
@@ -208,15 +208,15 @@ class LayerRepository:

        # Reference a specific layer by revision
        layer_repo = LayerRepository(
-           repo_id="username/my-kernel",
-           layer_name="MyLayer",
+           repo_id="kernels-community/activation",
+           layer_name="SiluAndMul",
        )

        # Reference a layer by version constraint
        layer_repo_versioned = LayerRepository(
-           repo_id="username/my-kernel",
-           layer_name="MyLayer",
-           version=">=1.0.0,<2.0.0"
+           repo_id="kernels-community/activation",
+           layer_name="SiluAndMul",
+           version=">=0.0.3,<0.1"
        )
        ```
    """
@@ -419,22 +419,36 @@ def use_kernel_mapping(

    Example:
        ```python
+       import torch
+       import torch.nn as nn
+       from torch.nn import functional as F
+
+       from kernels import use_kernel_forward_from_hub
        from kernels import use_kernel_mapping, LayerRepository, Device
+       from kernels import kernelize

        # Define a mapping
        mapping = {
-           "LayerNorm": {
+           "SiluAndMul": {
                "cuda": LayerRepository(
-                   repo_id="username/experimental-kernels",
-                   layer_name="FastLayerNorm"
+                   repo_id="kernels-community/activation",
+                   layer_name="SiluAndMul",
                )
            }
        }

+       @use_kernel_forward_from_hub("SiluAndMul")
+       class SiluAndMul(nn.Module):
+           def forward(self, x: torch.Tensor) -> torch.Tensor:
+               d = x.shape[-1] // 2
+               return F.silu(x[..., :d]) * x[..., d:]
+
+       model = SiluAndMul()
+
        # Use the mapping for the duration of the context.
        with use_kernel_mapping(mapping):
            # kernelize uses the temporary mapping
-           model = kernelize(model)
+           model = kernelize(model, device="cuda")

        # Outside the context, original mappings are restored
        ```
@@ -463,6 +477,7 @@ def register_kernel_mapping(
            Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
        ],
    ],
+   inherit_mapping: bool = True,
 ):
    """
    Register a global mapping between layer names and their corresponding kernel implementations.
@@ -474,6 +489,9 @@ def register_kernel_mapping(
        mapping (`Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]`):
            The kernel mapping to register globally. Maps layer names to device-specific kernels.
            The mapping can specify different kernels for different modes (training, inference, etc.).
+       inherit_mapping (`bool`, *optional*, defaults to `True`):
+           When `True`, the current mapping will be extended by `mapping`. When `False`, the existing mappings
+           are erased before adding `mapping`.

    Example:
        ```python
@@ -509,6 +527,9 @@ def register_kernel_mapping(
        register_kernel_mapping(advanced_mapping)
        ```
    """
+   if not inherit_mapping:
+       _KERNEL_MAPPING.set({})
+
    # Merge with existing mappings.
    for new_kernel, new_device_repos in mapping.items():
        device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
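For context, a usage sketch (not part of the diff) of the `inherit_mapping` flag introduced above; the repo and layer names are the ones used elsewhere in this commit. With `inherit_mapping=False` the existing global registrations are cleared before the new mapping is added, instead of being extended:

    from kernels import LayerRepository, register_kernel_mapping

    # Replace the global kernel mapping rather than extending it.
    register_kernel_mapping(
        {
            "SiluAndMul": {
                "cuda": LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="SiluAndMul",
                )
            }
        },
        inherit_mapping=False,
    )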
@@ -626,19 +647,23 @@ def kernelize(

    Example:
        ```python
-       from kernels import kernelize, Mode, register_kernel_mapping, LayerRepository
+       import torch
        import torch.nn as nn

-       @use_kernel_forward_from_hub("LayerNorm")
-       class LayerNorm(nn.Module):
-           ...
+       from kernels import kernelize, Mode, register_kernel_mapping, LayerRepository
+       from kernels import use_kernel_forward_from_hub
+
+       @use_kernel_forward_from_hub("SiluAndMul")
+       class SiluAndMul(nn.Module):
+           def forward(self, x: torch.Tensor) -> torch.Tensor:
+               d = x.shape[-1] // 2
+               return F.silu(x[..., :d]) * x[..., d:]

-       # First register some kernel mappings
+
        mapping = {
            "LayerNorm": {
                "cuda": LayerRepository(
-                   repo_id="username/fast-kernels",
-                   layer_name="FastLayerNorm"
+                   repo_id="kernels-community/activation",
+                   layer_name="SiluAndMul",
                )
            }
        }
@@ -646,9 +671,8 @@ def kernelize(

        # Create and kernelize a model
        model = nn.Sequential(
-           nn.Linear(768, 768),
-           LayerNorm(768),
-           nn.Linear(768, 768)
+           nn.Linear(1024, 2048, device="cuda"),
+           SiluAndMul(),
        )

        # Kernelize for inference
@@ -776,22 +800,25 @@ def use_kernel_forward_from_hub(layer_name: str):

    Example:
        ```python
-       from kernels import use_kernel_forward_from_hub
+       import torch
        import torch.nn as nn

+       from kernels import use_kernel_forward_from_hub, kernelize
+
        @use_kernel_forward_from_hub("MyCustomLayer")
        class MyCustomLayer(nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size

-           def forward(self, x):
+           def forward(self, x: torch.Tensor):
                # original implementation
                return x

-       # The layer can now be kernelized
        model = MyCustomLayer(768)
-       kernelized_model = kernelize(model)
+
+       # The layer can now be kernelized:
+       # model = kernelize(model, device="cuda")
        ```
    """

@@ -221,9 +221,13 @@ def get_kernel(

    Example:
        ```python
+       import torch
        from kernels import get_kernel
-       kernel = get_kernel("username/my-kernel")
-       result = kernel.kernel_function(input_data)
+
+       activation = get_kernel("kernels-community/activation")
+       x = torch.randn(10, 20, device="cuda")
+       out = torch.empty_like(x)
+       result = activation.silu_and_mul(out, x)
        ```
    """
    revision = select_revision_or_version(repo_id, revision, version)
@@ -1,10 +1,13 @@
 import sys

 import pytest
+import torch
+
+has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0


 def pytest_runtest_setup(item):
-    if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
-        pytest.skip("skipping Linux-only test on non-Linux platform")
+    if "cuda_only" in item.keywords and not has_cuda:
+        pytest.skip("skipping CUDA-only test on host without CUDA")
     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
         pytest.skip("skipping macOS-only test on non-macOS platform")
@@ -34,7 +34,7 @@ def device():
     return "cuda"


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_gelu_fast(kernel, device):
     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
     y = torch.empty_like(x)
@@ -50,7 +50,7 @@ def test_gelu_fast(kernel, device):
     assert torch.allclose(y, expected)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_local_kernel(local_kernel, device):
     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
     y = torch.empty_like(x)
@@ -74,7 +74,7 @@ def test_relu_metal(metal_kernel, dtype):
     assert torch.allclose(y, torch.relu(x))


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 @pytest.mark.parametrize(
     "kernel_exists",
     [
@@ -110,7 +110,7 @@ def test_version():
     )


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_universal_kernel(universal_kernel):
     torch.manual_seed(0)
     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
@@ -16,21 +16,21 @@ def device():
     return "cuda"


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_gelu_small(kernel, device, benchmark):
     x = torch.randn(32, 32, dtype=torch.float16, device=device)
     y = torch.empty_like(x)
     benchmark(kernel.gelu_fast, y, x)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_gelu_medium(kernel, device, benchmark):
     x = torch.randn(128, 128, dtype=torch.float16, device=device)
     y = torch.empty_like(x)
     benchmark(kernel.gelu_fast, y, x)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_gelu_large(kernel, device, benchmark):
     x = torch.randn(512, 512, dtype=torch.float16, device=device)
     y = torch.empty_like(x)
tests/test_doctest.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+import inspect
+
+import pytest
+from mktestdocs import check_docstring, get_codeblock_members
+
+import kernels
+
+
+def all_public_functions():
+    function_list = inspect.getmembers(kernels, inspect.isfunction)
+    return [func for _, func in function_list]
+
+
+def all_public_classes():
+    class_list = inspect.getmembers(kernels, inspect.isclass)
+    return [cls for _, cls in class_list]
+
+
+def all_public_class_members():
+    members = get_codeblock_members(*all_public_classes())
+    return members
+
+
+@pytest.mark.cuda_only
+@pytest.mark.parametrize(
+    "func",
+    all_public_functions(),
+    ids=lambda d: d.__name__,
+)
+def test_func_docstring(func):
+    check_docstring(obj=func)
+
+
+@pytest.mark.cuda_only
+@pytest.mark.parametrize(
+    "cls",
+    all_public_classes(),
+    ids=lambda d: d.__name__,
+)
+def test_class_docstring(cls):
+    check_docstring(obj=cls)
+
+
+@pytest.mark.cuda_only
+@pytest.mark.parametrize(
+    "member", all_public_class_members(), ids=lambda d: d.__qualname__
+)
+def test_member_docstring(member):
+    check_docstring(member)
@@ -27,7 +27,7 @@ def test_download_all_hash_validation():
     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_load_locked():
     project_dir = Path(__file__).parent / "kernel_locking"
     # Also validates that hashing works correctly.
@@ -13,12 +13,12 @@ from kernels import (
     kernelize,
     register_kernel_mapping,
     use_kernel_forward_from_hub,
+    use_kernel_mapping,
 )
 from kernels.layer import (
     _KERNEL_MAPPING,
     CUDAProperties,
     _validate_layer,
-    use_kernel_mapping,
 )

 kernel_layer_mapping = {
@@ -102,7 +102,7 @@ def test_arg_kinds():
     assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_hub_forward(cls, device):
@@ -124,7 +124,7 @@ def test_hub_forward(cls, device):
     assert silu_and_mul_with_kernel.n_calls == 1


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_capability():
     linear = TorchLinearWithCounter(32, 32).to("cuda")
     with use_kernel_mapping(
@@ -183,7 +183,7 @@ def test_layer_fallback_works():
     kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])
 def test_torch_compile_layer_without_fallback(cls, device):
@@ -214,7 +214,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
     torch.testing.assert_close(Y_compiled, Y)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])
 def test_torch_compile_layer_with_fallback(cls, device):
@@ -237,8 +237,11 @@ def test_torch_compile_layer_with_fallback(cls, device):
     torch.testing.assert_close(Y_compiled, Y)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_mapping_contexts():
+    # Make sure we start from scratch.
+    register_kernel_mapping(kernel_layer_mapping, inherit_mapping=False)
+
     assert set(_KERNEL_MAPPING.get().keys()) == {
         "SiluAndMul",
         "SiluAndMulStringDevice",
@@ -351,7 +354,7 @@ def test_validate_kernel_layer():
     _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_invalid_mode_for_mapping_rejected():
     linear = TorchLinearWithCounter(32, 32).to("cuda")

@@ -371,7 +374,7 @@ def test_invalid_mode_for_mapping_rejected():
     kernelize(linear, mode=Mode.TRAINING)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_kernel_modes():
     linear = TorchLinearWithCounter(32, 32).to("cuda")

@@ -515,7 +518,7 @@ def test_kernel_modes():
     assert linear.n_calls == 2


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_fallback_used_when_training():
     linear = TorchLinearWithCounter(32, 32).to("cuda")

@@ -580,7 +583,7 @@ def test_invalid_mode_rejected():
     kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_kernel_modes_inference():
     """Test inference-specific fallback scenarios."""
     linear = TorchLinearWithCounter(32, 32).to("cuda")
@@ -677,7 +680,7 @@ def test_kernel_modes_inference():
     assert linear.n_calls == 4


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_kernel_modes_mixed():
     """Test mixed training and inference kernel scenarios."""
     linear = TorchLinearWithCounter(32, 32).to("cuda")
@@ -767,7 +770,7 @@ def test_kernel_modes_mixed():
     assert linear.n_calls == 2


-@pytest.mark.linux_only
+@pytest.mark.cuda_only
 def test_kernel_modes_cross_fallback():
     """Test cross-mode fallback scenarios from inference to training modes."""
     linear = TorchLinearWithCounter(32, 32).to("cuda")