Files
kernels/tests/test_benchmarks.py
Daniël de Kok f7490bd0a9 Test examples in docstrings using mktestdocs (#118)
Also adjust examples so that they are correct.
2025-07-28 17:31:34 +02:00

38 lines
878 B
Python

import pytest
import torch
from kernels import get_kernel
@pytest.fixture
def kernel():
    """Provide the ``kernels-community/activation`` kernel module via ``get_kernel``."""
    repo_id = "kernels-community/activation"
    loaded = get_kernel(repo_id)
    return loaded
@pytest.fixture
def device():
    """Return the device string for benchmarks; skip the test when CUDA is absent."""
    if torch.cuda.is_available():
        return "cuda"
    # pytest.skip raises, so nothing after this line runs.
    pytest.skip("No CUDA")
def _benchmark_gelu(kernel, device, benchmark, size):
    """Benchmark ``kernel.gelu_fast`` on a ``size`` x ``size`` float16 input.

    Args:
        kernel: Kernel module exposing ``gelu_fast(out, input)``.
        device: Device string the tensors are allocated on (``"cuda"``).
        benchmark: The pytest-benchmark fixture.
        size: Edge length of the square input tensor.
    """
    x = torch.randn(size, size, dtype=torch.float16, device=device)
    y = torch.empty_like(x)
    # Input initialisation is queued asynchronously; finish it before timing
    # so it is not attributed to the first measured call.
    torch.cuda.synchronize(device)

    def run():
        kernel.gelu_fast(y, x)
        # CUDA launches return before the kernel finishes; without this sync
        # the benchmark would only measure launch overhead.
        torch.cuda.synchronize(device)

    benchmark(run)


@pytest.mark.cuda_only
def test_gelu_small(kernel, device, benchmark):
    """Benchmark ``gelu_fast`` on a 32x32 float16 tensor."""
    _benchmark_gelu(kernel, device, benchmark, 32)


@pytest.mark.cuda_only
def test_gelu_medium(kernel, device, benchmark):
    """Benchmark ``gelu_fast`` on a 128x128 float16 tensor."""
    _benchmark_gelu(kernel, device, benchmark, 128)


@pytest.mark.cuda_only
def test_gelu_large(kernel, device, benchmark):
    """Benchmark ``gelu_fast`` on a 512x512 float16 tensor."""
    _benchmark_gelu(kernel, device, benchmark, 512)