Compare commits


4 Commits
ruff ... v0.2.0

SHA1        Message                                     Date
18ecd0ce69  Set version to 0.2.0 (#41)                  2025-03-10 10:24:02 +01:00
b4ef1d60e5  Update torch dependency to 2.5 (#40)        2025-03-07 20:32:54 +01:00
            Fixes #37.
a40756f306  Configure ruff lints and add to CI (#39)    2025-03-07 20:32:44 +01:00
3671158f47  Rename noarch to universal (#38)            2025-03-07 15:12:44 +01:00
            Also update docs to mention this variant.
6 changed files with 49 additions and 44 deletions

View File

@@ -26,8 +26,13 @@ recommended build variants are:
- `torch26-cxx98-cu124-x86_64-linux`
- `torch26-cxx98-cu126-x86_64-linux`
-This list will be updated as new PyTorch versions are released. Each
-variant directory should contain a single directory with the same name
+This list will be updated as new PyTorch versions are released. Kernels
+that are in pure Python (e.g. Triton kernels) only need to provide a
+single build variant:
+- `torch-universal`
+Each variant directory should contain a single directory with the same name
as the repository (replacing `-` by `_`). For instance, kernels in the
`kernels-community/activation` repository have directories like
`build/<variant>/activation`. This directory
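For illustration, a repository that ships one compiled build and the universal build would be laid out roughly as follows (hypothetical tree, reusing the `activation` example and the variant names from the README excerpt above):

```
build/
├── torch26-cxx98-cu124-x86_64-linux/
│   └── activation/
│       └── __init__.py
└── torch-universal/
    └── activation/
        └── __init__.py
```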

View File

@@ -1,6 +1,6 @@
[project]
name = "kernels"
-version = "0.1.7"
+version = "0.2.0"
description = "Download cuda kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -14,7 +14,7 @@ dependencies = [
"huggingface-hub>=0.26.3",
"packaging>=24.2",
"tomli>=2.0.1; python_version<'3.11'",
-"torch>=2.4",
+"torch>=2.5",
]
[build-system]
@@ -38,24 +38,24 @@ kernels = "kernels.cli:main"
[tool.ruff]
exclude = [
-".eggs",
-".git",
-".git-rewrite",
-".hg",
-".mypy_cache",
-".nox",
-".pants.d",
-".pytype",
-".ruff_cache",
-".svn",
-".tox",
-".venv",
-".venv*",
-"__pypackages__",
-"_build",
-"build",
-"dist",
-"venv",
+".eggs",
+".git",
+".git-rewrite",
+".hg",
+".mypy_cache",
+".nox",
+".pants.d",
+".pytype",
+".ruff_cache",
+".svn",
+".tox",
+".venv",
+".venv*",
+"__pypackages__",
+"_build",
+"build",
+"dist",
+"venv",
]
line-length = 119
# Ignored rules:

View File

@@ -37,9 +37,9 @@ def build_variant() -> str:
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
-def noarch_build_variant() -> str:
+def universal_build_variant() -> str:
# Once we support other frameworks, detection goes here.
-return "torch-noarch"
+return "torch-universal"
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
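As a sketch of what the two helpers above produce: the compiled-variant name encodes the detected PyTorch, C++ ABI, CUDA and platform combination, while the universal name is fixed. The minimal snippet below mirrors the f-string shown in the hunk; the concrete values are assumptions chosen to match the README's `torch26-cxx98-cu124-x86_64-linux` example, not output detected from a real environment:

```python
# Minimal sketch (not the library's code): how a compiled-variant name is
# assembled versus the fixed universal name. All values below are assumed.
torch_major, torch_minor = 2, 6    # assumed PyTorch version
cxxabi = "cxx98"                   # assumed C++ ABI tag
cuda_major, cuda_minor = 12, 4     # assumed CUDA version
cpu, os_name = "x86_64", "linux"   # assumed platform

compiled_variant = (
    f"torch{torch_major}{torch_minor}-{cxxabi}-cu{cuda_major}{cuda_minor}-{cpu}-{os_name}"
)
universal_variant = "torch-universal"

print(compiled_variant)   # torch26-cxx98-cu124-x86_64-linux
print(universal_variant)  # torch-universal
```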
@@ -73,11 +73,11 @@ install_kernel(
"""
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
-noarch_variant = noarch_build_variant()
+universal_variant = universal_build_variant()
repo_path = Path(
snapshot_download(
repo_id,
-allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
+allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
@@ -85,12 +85,12 @@
)
variant_path = repo_path / "build" / variant
-noarch_variant_path = repo_path / "build" / noarch_variant
+universal_variant_path = repo_path / "build" / universal_variant
-if not variant_path.exists() and noarch_variant_path.exists():
-# Fall back to noarch variant.
-variant = noarch_variant
-variant_path = noarch_variant_path
+if not variant_path.exists() and universal_variant_path.exists():
+# Fall back to universal variant.
+variant = universal_variant
+variant_path = universal_variant_path
if variant_locks is not None:
variant_lock = variant_locks.get(variant)
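The fallback in this hunk is the behavioral core of the rename: when no build matching the detected variant exists, the universal (pure-Python) build is used instead. A minimal, self-contained sketch of that resolution logic, as a hypothetical helper rather than part of the library's API:

```python
from pathlib import Path


def resolve_variant_path(repo_path: Path, variant: str, universal_variant: str) -> Path:
    """Prefer the architecture-specific build; fall back to the universal build."""
    variant_path = repo_path / "build" / variant
    universal_variant_path = repo_path / "build" / universal_variant
    if not variant_path.exists() and universal_variant_path.exists():
        return universal_variant_path
    return variant_path
```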
@@ -156,23 +156,23 @@ def load_kernel(repo_id: str) -> ModuleType:
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
-noarch_variant = noarch_build_variant()
+universal_variant = universal_build_variant()
repo_path = Path(
snapshot_download(
repo_id,
-allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
+allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
cache_dir=CACHE_DIR,
local_files_only=True,
)
)
variant_path = repo_path / "build" / variant
-noarch_variant_path = repo_path / "build" / noarch_variant
-if not variant_path.exists() and noarch_variant_path.exists():
-# Fall back to noarch variant.
-variant = noarch_variant
-variant_path = noarch_variant_path
+universal_variant_path = repo_path / "build" / universal_variant
+if not variant_path.exists() and universal_variant_path.exists():
+# Fall back to universal variant.
+variant = universal_variant
+variant_path = universal_variant_path
module_init_path = variant_path / package_name / "__init__.py"
if not os.path.exists(module_init_path):

View File

@@ -55,9 +55,9 @@
},
{
"repo_id": "kernels-community/triton-scaled-mm",
-"sha": "9baccbeb763fe5f1b8fbdb9c1e5699548c46632c",
+"sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f",
"variants": {
-"torch-noarch": {
+"torch-universal": {
"hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52",
"hash_type": "git_lfs_concat"
}

View File

@@ -1,3 +1,3 @@
[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.2"
-"kernels-community/triton-scaled-mm" = ">=0.0.1"
+"kernels-community/triton-scaled-mm" = ">=0.0.2"

View File

@@ -10,7 +10,7 @@ def kernel():
@pytest.fixture
-def noarch_kernel():
+def universal_kernel():
return get_kernel("kernels-community/triton-scaled-mm")
@@ -36,14 +36,14 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
-def test_noarch_kernel(noarch_kernel):
+def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda")
scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda")
scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda")
-out = noarch_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
+out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
out_check = (A * scale_a) @ (B * scale_b)
out_check = out_check.to(torch.float16)