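"""Tests for hub layer kernelization: kernelize modes, device and capability
selection, mapping contexts, layer validation, and versioned layer repositories."""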
import sys
from contextlib import nullcontext

import pytest
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import (
    Device,
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)
from kernels.layer import (
    _KERNEL_MAPPING,
    CUDAProperties,
    _validate_layer,
    use_kernel_mapping,
)

kernel_layer_mapping = {
    "SiluAndMul": {
        Device(type="cuda"): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    },
    "SiluAndMulNoCompile": {
        "cuda": LayerRepository(
            repo_id="kernels-test/op-without-fake-test",
            layer_name="SiluAndMul",
        )
    },
    "SiluAndMulStringDevice": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    },
}

register_kernel_mapping(kernel_layer_mapping)


class SiluAndMul(nn.Module):
    def __init__(self):
        super().__init__()
        # Used to check that we called hub kernel.
        self.n_calls = 0

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        self.n_calls += 1
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


@use_kernel_forward_from_hub("SiluAndMulNoCompile")
class SiluAndMulNoCompileKernel(SiluAndMul):
    pass


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMulWithKernel(SiluAndMul):
    pass


@use_kernel_forward_from_hub("SiluAndMulStringDevice")
class SiluAndMulStringDevice(SiluAndMul):
    pass


@use_kernel_forward_from_hub("Linear")
class TorchLinearWithCounter(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Used to check that we called hub kernel.
        self.n_calls = 0

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        self.n_calls += 1
        return super().forward(input)


def test_arg_kinds():
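    """Positional, keyword-only, and defaulted arguments pass through a hub-extended forward unchanged."""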
    @use_kernel_forward_from_hub("ArgKind")
    class ArgKind(nn.Module):
        def forward(
            self,
            arg1,
            arg2,
            *,
            kwarg1,
            kwarg2=42,
        ):
            return (arg1, arg2, kwarg1, kwarg2)

    arg_kind = ArgKind()
    assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
    assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)


@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
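    """The kernelized layer matches the reference output; on CUDA the hub kernel
    replaces the original forward, on CPU the original forward is used."""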
    torch.random.manual_seed(0)

    silu_and_mul = SiluAndMul()
    X = torch.randn((32, 64), device=device)
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
    Y_kernel = silu_and_mul_with_kernel(X)

    torch.testing.assert_close(Y_kernel, Y)

    assert silu_and_mul.n_calls == 1
    if device == "cuda":
        assert silu_and_mul_with_kernel.n_calls == 0
    else:
        assert silu_and_mul_with_kernel.n_calls == 1


@pytest.mark.linux_only
def test_capability():
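    """A kernel is only used when the device's CUDA capability falls within the
    mapping's CUDAProperties range; otherwise the original layer runs."""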
    linear = TorchLinearWithCounter(32, 32).to("cuda")
    with use_kernel_mapping(
        {
            "Linear": {
                Device(
                    type="cuda",
                    properties=CUDAProperties(
                        min_capability=75, max_capability=sys.maxsize
                    ),
                ): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)

        # Check that we called out to the kernel.
        assert linear.n_calls == 0

    with use_kernel_mapping(
        {
            "Linear": {
                Device(
                    type="cuda",
                    properties=CUDAProperties(
                        min_capability=sys.maxsize, max_capability=sys.maxsize
                    ),
                ): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)

        # Check that we didn't call out to the kernel because there
        # is no kernel with a matching capability.
        assert linear.n_calls == 1


def test_layer_fallback_works():
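    """kernelize does not raise for a layer name without a usable hub kernel."""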
    @use_kernel_forward_from_hub("SiluAndMulNonExisting")
    class SiluAndMulWithKernelFallback(SiluAndMul):
        pass

    # Check that we don't raise an exception for a non-existing kernel.
    silu_and_mul = SiluAndMulWithKernelFallback()
    kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)


@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device):
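    """With use_fallback=False, kernelize raises for kernels that do not support
    INFERENCE | TORCH_COMPILE; otherwise the compiled module matches the reference."""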
    silu_and_mul = SiluAndMul()

    X = torch.randn((32, 64), dtype=torch.float32, device=device)
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = cls()
    silu_and_mul_with_kernel.eval()

    ctx = (
        pytest.raises(ValueError, match="does not support mode")
        if cls is SiluAndMulNoCompileKernel
        else nullcontext()
    )
    with ctx:
        silu_and_mul_with_kernel = kernelize(
            silu_and_mul_with_kernel,
            device=device,
            mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
            use_fallback=False,
        )
        silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)

        Y_compiled = silu_and_mul_compiled(X)

        torch.testing.assert_close(Y_compiled, Y)


@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device):
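    """With the default fallback, kernelize plus torch.compile matches the reference
    output even for kernels without torch.compile support."""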
    silu_and_mul = SiluAndMul()

    X = torch.randn((32, 64), dtype=torch.float32, device=device)
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = cls()
    silu_and_mul_with_kernel.eval()
    silu_and_mul_with_kernel = kernelize(
        silu_and_mul_with_kernel,
        device=device,
        mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
    )
    silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)

    Y_compiled = silu_and_mul_compiled(X)

    torch.testing.assert_close(Y_compiled, Y)


@pytest.mark.linux_only
def test_mapping_contexts():
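    """use_kernel_mapping extends the global mapping inside the context and restores
    the previous state on exit; inherit_mapping=False replaces it instead."""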
    assert set(_KERNEL_MAPPING.get().keys()) == {
        "SiluAndMul",
        "SiluAndMulStringDevice",
        "SiluAndMulNoCompile",
    }

    extra_mapping1 = {
        "TestKernel": {
            Device(type="cuda"): LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
                revision="layers",
            )
        }
    }

    with use_kernel_mapping(extra_mapping1):
        assert set(_KERNEL_MAPPING.get().keys()) == {
            "SiluAndMul",
            "SiluAndMulStringDevice",
            "SiluAndMulNoCompile",
            "TestKernel",
        }

        extra_mapping2 = {
            "SiluAndMul": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-community/non-existing",
                    layer_name="SiluAndMul",
                    revision="layers",
                )
            }
        }

        with use_kernel_mapping(extra_mapping2):
            assert set(_KERNEL_MAPPING.get().keys()) == {
                "SiluAndMul",
                "SiluAndMulStringDevice",
                "SiluAndMulNoCompile",
                "TestKernel",
            }
            assert (
                _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
                == "kernels-community/non-existing"
            )

        assert set(_KERNEL_MAPPING.get().keys()) == {
            "SiluAndMul",
            "SiluAndMulStringDevice",
            "SiluAndMulNoCompile",
            "TestKernel",
        }
        assert (
            _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
            == "kernels-community/activation"
        )

        with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
            assert set(_KERNEL_MAPPING.get().keys()) == {
                "SiluAndMul",
            }
            assert (
                _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
                == "kernels-community/non-existing"
            )

        assert set(_KERNEL_MAPPING.get().keys()) == {
            "SiluAndMul",
            "SiluAndMulStringDevice",
            "SiluAndMulNoCompile",
            "TestKernel",
        }
        assert (
            _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
            == "kernels-community/activation"
        )

    assert set(_KERNEL_MAPPING.get().keys()) == {
        "SiluAndMul",
        "SiluAndMulStringDevice",
        "SiluAndMulNoCompile",
    }


def test_validate_kernel_layer():
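    """_validate_layer rejects layers that override __init__, add extra members, or
    change the forward signature."""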
    class BadLayer(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.foo = 42

    with pytest.raises(TypeError, match="not override"):
        _validate_layer(cls=BadLayer, check_cls=SiluAndMul)

    class BadLayer2(nn.Module):
        foo: int = 42

    with pytest.raises(TypeError, match="not contain additional members"):
        _validate_layer(cls=BadLayer2, check_cls=SiluAndMul)

    class BadLayer3(nn.Module):
        def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...

    with pytest.raises(TypeError, match="different number of arguments"):
        _validate_layer(cls=BadLayer3, check_cls=SiluAndMul)

    class BadLayer4(nn.Module):
        def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...

    with pytest.raises(TypeError, match="different kind of arguments"):
        _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)


@pytest.mark.linux_only
def test_invalid_mode_for_mapping_rejected():
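    """Kernelizing for TRAINING with a kernel that lacks backward support raises."""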
    linear = TorchLinearWithCounter(32, 32).to("cuda")

    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.TRAINING: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearNoBackward",
                    )
                }
            }
        }
    ):
        with pytest.raises(ValueError, match="does not support backward"):
            kernelize(linear, mode=Mode.TRAINING)


@pytest.mark.linux_only
def test_kernel_modes():
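    """Mode-specific kernel selection and fallback between FALLBACK, TRAINING, and
    TRAINING | TORCH_COMPILE registrations."""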
    linear = TorchLinearWithCounter(32, 32).to("cuda")

    # Case 1: a layer registered without further mode specification
    # becomes the base layer.
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        assert linear.n_calls == 0

        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
        kernelize(linear)
        linear(X)
        assert linear.n_calls == 0

    # Case 2: register a kernel just for training. If no base kernel
    # layer is registered, we fall back to the original layer.
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.TRAINING: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    )
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # Training has a kernel, so the kernel is used rather than the original layer.
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
        assert linear.n_calls == 1

        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
        kernelize(linear)
        linear(X)
        # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
        assert linear.n_calls == 2

    # Case 3: register a kernel just for training and one for fallback.
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.FALLBACK: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                    Mode.TRAINING: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        # Falls back to TRAINING.
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # Falls back to the TRAINING kernel.
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
        assert linear.n_calls == 2

        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
        kernelize(linear)
        linear(X)
        # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
        assert linear.n_calls == 2

    # Case 4: register a kernel with two preferences.
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.TRAINING
                    | Mode.TORCH_COMPILE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    )
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        # Falls back to the TRAINING | TORCH_COMPILE kernel.
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        # Uses TRAINING | TORCH_COMPILE kernel.
        assert linear.n_calls == 2

        kernelize(linear)
        linear(X)
        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
        assert linear.n_calls == 2


@pytest.mark.linux_only
def test_fallback_used_when_training():
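    """Kernels with explicit or implicit backward support are used in both train()
    and eval() mode."""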
    linear = TorchLinearWithCounter(32, 32).to("cuda")

    # Case 1: kernel with explicit backward support should always
    # use the kernel.
    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    ):
        linear.train()
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        linear.eval()
        linear(X)
        assert linear.n_calls == 0

    # Case 2: kernel with implicit backward support should always
    # use the kernel.
    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearImplicitBackward",
                )
            }
        }
    ):
        linear.train()
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        linear.eval()
        linear(X)
        assert linear.n_calls == 0


def test_invalid_mode_rejected():
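    """Invalid mode combinations and mapping-only modes are rejected with ValueError."""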
    with pytest.raises(ValueError, match="mutually exclusive"):
        _ = Mode.INFERENCE | Mode.TRAINING

    with pytest.raises(ValueError, match="cannot be combined with other modes"):
        _ = Mode.FALLBACK | Mode.TORCH_COMPILE

    with pytest.raises(
        ValueError, match="can only be used to register kernel mappings"
    ):
        kernelize(torch.nn.Linear(32, 32), mode=Mode.FALLBACK)

    with pytest.raises(ValueError, match="mode must contain"):
        kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)


@pytest.mark.linux_only
def test_kernel_modes_inference():
    """Test inference-specific fallback scenarios."""
    linear = TorchLinearWithCounter(32, 32).to("cuda")

    # Case 1: register a kernel just for inference
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.INFERENCE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    )
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
        linear(X)
        # INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
        assert linear.n_calls == 1

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # No training kernel, so fallback to original
        assert linear.n_calls == 2

    # Case 2: register a kernel just for inference + torch.compile
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.INFERENCE
                    | Mode.TORCH_COMPILE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    )
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.INFERENCE)
        linear(X)
        # INFERENCE falls back to INFERENCE | TORCH_COMPILE kernel
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # No training kernel, so fallback to original
        assert linear.n_calls == 3

    # Case 3: register both inference kernels
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.INFERENCE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                    Mode.INFERENCE
                    | Mode.TORCH_COMPILE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        # Uses exact INFERENCE kernel
        assert linear.n_calls == 3

        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
        linear(X)
        # Uses exact INFERENCE | TORCH_COMPILE kernel
        assert linear.n_calls == 3

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # No training kernel, so fallback to original
        assert linear.n_calls == 4


@pytest.mark.linux_only
def test_kernel_modes_mixed():
    """Test mixed training and inference kernel scenarios."""
    linear = TorchLinearWithCounter(32, 32).to("cuda")

    # Case 1: register both base inference and training kernels
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.INFERENCE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                    Mode.TRAINING: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
        linear(X)
        # INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
        assert linear.n_calls == 1

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original
        assert linear.n_calls == 2

    # Case 2: register all four kernel modes
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.INFERENCE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                    Mode.TRAINING: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                    Mode.INFERENCE
                    | Mode.TORCH_COMPILE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                    Mode.TRAINING
                    | Mode.TORCH_COMPILE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        # Uses exact INFERENCE kernel
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # Uses exact TRAINING kernel
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
        linear(X)
        # Uses exact INFERENCE | TORCH_COMPILE kernel
        assert linear.n_calls == 2

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        # Uses exact TRAINING | TORCH_COMPILE kernel
        assert linear.n_calls == 2


@pytest.mark.linux_only
def test_kernel_modes_cross_fallback():
    """Test cross-mode fallback scenarios from inference to training modes."""
    linear = TorchLinearWithCounter(32, 32).to("cuda")

    # Case 1: Only training kernel registered - inference should fall back to training
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.TRAINING: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    )
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        # INFERENCE falls back to TRAINING kernel
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # TRAINING uses the kernel directly
        assert linear.n_calls == 0

    # Case 2: Only training + torch.compile kernel registered
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.TRAINING
                    | Mode.TORCH_COMPILE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    )
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.INFERENCE)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        # INFERENCE falls back to TRAINING | TORCH_COMPILE kernel
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
        linear(X)
        # INFERENCE | TORCH_COMPILE falls back to TRAINING | TORCH_COMPILE kernel
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING)
        linear(X)
        # TRAINING falls back to TRAINING | TORCH_COMPILE kernel
        assert linear.n_calls == 0

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        # TRAINING | TORCH_COMPILE uses the kernel directly
        assert linear.n_calls == 0

    # Case 3: Test that training modes don't fall back to inference modes
    with use_kernel_mapping(
        {
            "Linear": {
                "cuda": {
                    Mode.INFERENCE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                    Mode.INFERENCE
                    | Mode.TORCH_COMPILE: LayerRepository(
                        repo_id="kernels-test/backward-marker-test",
                        layer_name="LinearBackward",
                    ),
                }
            }
        }
    ):
        kernelize(linear, mode=Mode.TRAINING)
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        # TRAINING should NOT fall back to inference kernels, use original
        assert linear.n_calls == 1

        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
        linear(X)
        # TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
        assert linear.n_calls == 2


def test_layer_versions():
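    """Version constraints on LayerRepository select a matching kernel version;
    unsatisfiable constraints or mixing revision and version raise ValueError."""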
    @use_kernel_forward_from_hub("Version")
    class Version(nn.Module):
        def forward(self) -> str:
            return "0.0.0"

    version = Version()

    with use_kernel_mapping(
        {
            "Version": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                )
            }
        }
    ):
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
        assert version() == "0.2.0"

    with use_kernel_mapping(
        {
            "Version": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                    version="<1.0.0",
                )
            }
        }
    ):
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
        assert version() == "0.2.0"

    with use_kernel_mapping(
        {
            "Version": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                    version="<0.2.0",
                )
            }
        }
    ):
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
        assert version() == "0.1.1"

    with use_kernel_mapping(
        {
            "Version": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                    version=">0.1.0,<0.2.0",
                )
            }
        }
    ):
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
        assert version() == "0.1.1"

    with use_kernel_mapping(
        {
            "Version": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                    version=">0.2.0",
                )
            }
        }
    ):
        with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
            kernelize(version, device="cuda", mode=Mode.INFERENCE)

    with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
        use_kernel_mapping(
            {
                "Version": {
                    Device(type="cuda"): LayerRepository(
                        repo_id="kernels-test/versions",
                        layer_name="Version",
                        revision="v0.1.0",
                        version="<1.0.0",
                    )
                }
            }
        )