Mirror of https://github.com/huggingface/kernels.git
Add support for Metal builds (#89)
* Add support for Metal builds
* Add Metal test, gate tests by OS where necessary
pytest.ini (new file, 4 lines)
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+    darwin_only: marks tests that should only run on macOS
+    linux_only: marks tests that should only run on Linux
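The markers registered above only become meaningful once tests opt into them; a minimal sketch of how a marked test looks and how it can be deselected from the command line (the test name here is purely illustrative):

# Sketch: using the markers registered in pytest.ini (test name is hypothetical).
import sys

import pytest


@pytest.mark.darwin_only
def test_something_macos_specific():
    # Only runs on macOS once the conftest hook further below skips it elsewhere.
    assert sys.platform.startswith("darwin")

# Deselect macOS-only tests explicitly:
#   pytest -m "not darwin_only"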
@@ -43,14 +43,22 @@ def build_variant() -> str:
     elif torch.version.hip is not None:
         rocm_version = parse(torch.version.hip.split("-")[0])
         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
+    elif torch.backends.mps.is_available():
+        compute_framework = "metal"
     else:
-        raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
+        raise AssertionError(
+            "Torch was not compiled with CUDA, Metal, or ROCm enabled."
+        )
 
     torch_version = parse(torch.__version__)
-    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
     cpu = platform.machine()
     os = platform.system().lower()
 
+    if os == "darwin":
+        return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
+
+    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
+
     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
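For orientation, a small sketch of the variant names this scheme yields; the version numbers and CPU architectures are illustrative, and only the branches visible in the hunk above are mirrored:

# Sketch: reproducing the variant naming scheme shown above.
# torch 2.6, ROCm 6.3, x86_64/arm64 are illustrative values, not real build output.
def variant(torch_ver: str, compute: str, cpu: str, os: str, cxxabi: str = "cxx11") -> str:
    if os == "darwin":
        # Darwin builds skip the cxxabi component, as in the early return above.
        return f"torch{torch_ver}-{compute}-{cpu}-{os}"
    return f"torch{torch_ver}-{cxxabi}-{compute}-{cpu}-{os}"


print(variant("26", "rocm63", "x86_64", "linux"))  # torch26-cxx11-rocm63-x86_64-linux
print(variant("26", "metal", "arm64", "darwin"))   # torch26-metal-arm64-darwin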
tests/conftest.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+import sys
+
+import pytest
+
+
+def pytest_runtest_setup(item):
+    if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
+        pytest.skip("skipping Linux-only test on non-Linux platform")
+    if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
+        pytest.skip("skipping macOS-only test on non-macOS platform")
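The hook above centralizes the OS gating; for comparison, a per-test equivalent (a sketch, not part of this commit) would use an explicit skipif on each test:

# Sketch: per-test equivalent of the conftest hook; not part of this commit.
import sys

import pytest


@pytest.mark.skipif(
    not sys.platform.startswith("darwin"),
    reason="skipping macOS-only test on non-macOS platform",
)
def test_macos_only_example():
    ...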
@@ -9,6 +9,11 @@ def kernel():
     return get_kernel("kernels-community/activation")
 
 
+@pytest.fixture
+def metal_kernel():
+    return get_kernel("kernels-test/relu-metal")
+
+
 @pytest.fixture
 def universal_kernel():
     return get_kernel("kernels-community/triton-scaled-mm")
@@ -21,6 +26,7 @@ def device():
     return "cuda"
 
 
+@pytest.mark.linux_only
 def test_gelu_fast(kernel, device):
     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
     y = torch.empty_like(x)
@@ -36,6 +42,15 @@ def test_gelu_fast(kernel, device):
     assert torch.allclose(y, expected)
 
 
+@pytest.mark.darwin_only
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_relu_metal(metal_kernel, dtype):
+    x = torch.arange(-10, 10, dtype=dtype, device="mps")
+    y = metal_kernel.relu(x)
+    assert torch.allclose(y, torch.relu(x))
+
+
+@pytest.mark.linux_only
 @pytest.mark.parametrize(
     "kernel_exists",
     [
@@ -52,6 +67,7 @@ def test_has_kernel(kernel_exists):
     assert has_kernel(repo_id, revision=revision) == kernel
 
 
+@pytest.mark.linux_only
 def test_universal_kernel(universal_kernel):
     torch.manual_seed(0)
     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
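Outside the test suite, the Metal kernel exercised by test_relu_metal above can be tried directly; a minimal sketch, assuming a torch build with MPS support and access to the kernels-test/relu-metal repo:

# Sketch: calling the relu-metal kernel directly, mirroring test_relu_metal above.
import torch

from kernels import get_kernel

if torch.backends.mps.is_available():
    relu_kernel = get_kernel("kernels-test/relu-metal")
    x = torch.arange(-10, 10, dtype=torch.float16, device="mps")
    print(relu_kernel.relu(x))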
@@ -16,18 +16,21 @@ def device():
     return "cuda"
 
 
+@pytest.mark.linux_only
 def test_gelu_small(kernel, device, benchmark):
     x = torch.randn(32, 32, dtype=torch.float16, device=device)
     y = torch.empty_like(x)
     benchmark(kernel.gelu_fast, y, x)
 
 
+@pytest.mark.linux_only
 def test_gelu_medium(kernel, device, benchmark):
     x = torch.randn(128, 128, dtype=torch.float16, device=device)
     y = torch.empty_like(x)
     benchmark(kernel.gelu_fast, y, x)
 
 
+@pytest.mark.linux_only
 def test_gelu_large(kernel, device, benchmark):
     x = torch.randn(512, 512, dtype=torch.float16, device=device)
     y = torch.empty_like(x)
@@ -1,6 +1,8 @@
 from dataclasses import dataclass
 from pathlib import Path
 
+import pytest
+
 from kernels import load_kernel
 from kernels.cli import download_kernels
 
@@ -17,6 +19,7 @@ def test_download_all_hash_validation():
     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
 
 
+@pytest.mark.linux_only
 def test_load_locked():
     project_dir = Path(__file__).parent / "kernel_locking"
     # Also validates that hashing works correctly.
@@ -82,6 +82,7 @@ def test_arg_kinds():
     assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
 
 
+@pytest.mark.linux_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_hub_forward(cls, device):
@@ -112,6 +113,7 @@ def test_layer_fallback_works():
     SiluAndMulWithKernelFallback()
 
 
+@pytest.mark.linux_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_torch_compile_layer(cls, device):
@@ -242,6 +244,7 @@ def test_validate_kernel_layer():
     _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
 
 
+@pytest.mark.linux_only
 def test_fallback_used_when_training():
     @use_kernel_forward_from_hub("Linear")
     class TorchLinear(nn.Linear):