mirror of https://github.com/huggingface/kernels.git
synced 2025-11-02 04:34:27 +08:00
Compare commits
4 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 18ecd0ce69 | |
| | b4ef1d60e5 | |
| | a40756f306 | |
| | 3671158f47 | |
@@ -26,8 +26,13 @@ recommended build variants are:
 - `torch26-cxx98-cu124-x86_64-linux`
 - `torch26-cxx98-cu126-x86_64-linux`

-This list will be updated as new PyTorch versions are released. Each
-variant directory should contain a single directory with the same name
+This list will be updated as new PyTorch versions are released. Kernels
+that are in pure Python (e.g. Triton kernels) only need to provide a
+single build variant:
+
+- `torch-universal`
+
+Each variant directory should contain a single directory with the same name
 as the repository (replacing `-` by `_`). For instance, kernels in the
 `kernels-community/activation` repository have a directories like
 `build/<variant>/activation`. This directory
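To make the new README guidance concrete, here is a brief sketch of what a pure-Python kernel repository looks like under the `torch-universal` convention and how it is consumed. The package name is illustrative (taken from the lock-file change below), and importing `get_kernel` from the top-level `kernels` package is an assumption based on the test changes in this compare.

```python
# Sketch only: repository layout implied by the README change above.
# A pure-Python (e.g. Triton) kernel ships a single build variant:
#
#   build/torch-universal/triton_scaled_mm/__init__.py
#
# The import location of get_kernel is assumed from the tests further down.
from kernels import get_kernel

# Fetches build/torch-universal/* when no compiled variant matches this machine.
scaled_mm = get_kernel("kernels-community/triton-scaled-mm")
```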
@@ -1,6 +1,6 @@
 [project]
 name = "kernels"
-version = "0.1.7"
+version = "0.2.0"
 description = "Download cuda kernels"
 authors = [
     { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -14,7 +14,7 @@ dependencies = [
     "huggingface-hub>=0.26.3",
     "packaging>=24.2",
     "tomli>=2.0.1; python_version<'3.11'",
-    "torch>=2.4",
+    "torch>=2.5",
 ]

 [build-system]
@@ -38,24 +38,24 @@ kernels = "kernels.cli:main"

 [tool.ruff]
 exclude = [
-    ".eggs",
-    ".git",
-    ".git-rewrite",
-    ".hg",
-    ".mypy_cache",
-    ".nox",
-    ".pants.d",
-    ".pytype",
-    ".ruff_cache",
-    ".svn",
-    ".tox",
-    ".venv",
-    ".venv*",
-    "__pypackages__",
-    "_build",
-    "build",
-    "dist",
-    "venv",
+    ".eggs",
+    ".git",
+    ".git-rewrite",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    ".venv*",
+    "__pypackages__",
+    "_build",
+    "build",
+    "dist",
+    "venv",
 ]
 line-length = 119
 # Ignored rules:
@@ -37,9 +37,9 @@ def build_variant() -> str:
     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"


-def noarch_build_variant() -> str:
+def universal_build_variant() -> str:
     # Once we support other frameworks, detection goes here.
-    return "torch-noarch"
+    return "torch-universal"


 def import_from_path(module_name: str, file_path: Path) -> ModuleType:
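For reference, a minimal sketch of the strings these two helpers return; the `kernels.utils` module path is an assumption (file names are not shown in this compare view), and the compiled-variant value is only an example matching the README list above.

```python
# Assumed module path; this compare view does not show file names.
from kernels.utils import build_variant, universal_build_variant

# Depends on the local torch/CUDA install, e.g. "torch26-cxx98-cu124-x86_64-linux"
# for PyTorch 2.6 with CUDA 12.4 on x86_64 Linux (see the README list above).
print(build_variant())

# Pure-Python kernels always map to the single universal variant.
print(universal_build_variant())  # "torch-universal"
```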
@@ -73,11 +73,11 @@ def install_kernel(
     """
     package_name = package_name_from_repo_id(repo_id)
     variant = build_variant()
-    noarch_variant = noarch_build_variant()
+    universal_variant = universal_build_variant()
     repo_path = Path(
         snapshot_download(
             repo_id,
-            allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
+            allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
             cache_dir=CACHE_DIR,
             revision=revision,
             local_files_only=local_files_only,
@@ -85,12 +85,12 @@ install_kernel(
     )

     variant_path = repo_path / "build" / variant
-    noarch_variant_path = repo_path / "build" / noarch_variant
+    universal_variant_path = repo_path / "build" / universal_variant

-    if not variant_path.exists() and noarch_variant_path.exists():
-        # Fall back to noarch variant.
-        variant = noarch_variant
-        variant_path = noarch_variant_path
+    if not variant_path.exists() and universal_variant_path.exists():
+        # Fall back to universal variant.
+        variant = universal_variant
+        variant_path = universal_variant_path

     if variant_locks is not None:
         variant_lock = variant_locks.get(variant)
@@ -156,23 +156,23 @@ def load_kernel(repo_id: str) -> ModuleType:
     package_name = package_name_from_repo_id(repo_id)

     variant = build_variant()
-    noarch_variant = noarch_build_variant()
+    universal_variant = universal_build_variant()

     repo_path = Path(
         snapshot_download(
             repo_id,
-            allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
+            allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
             cache_dir=CACHE_DIR,
             local_files_only=True,
         )
     )

     variant_path = repo_path / "build" / variant
-    noarch_variant_path = repo_path / "build" / noarch_variant
-    if not variant_path.exists() and noarch_variant_path.exists():
-        # Fall back to noarch variant.
-        variant = noarch_variant
-        variant_path = noarch_variant_path
+    universal_variant_path = repo_path / "build" / universal_variant
+    if not variant_path.exists() and universal_variant_path.exists():
+        # Fall back to universal variant.
+        variant = universal_variant
+        variant_path = universal_variant_path

     module_init_path = variant_path / package_name / "__init__.py"
     if not os.path.exists(module_init_path):
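A short usage sketch of the renamed fallback from the caller's side; the top-level import location and any keyword arguments beyond those visible in the hunks above are assumptions.

```python
# Sketch under assumptions: only behaviour visible in the hunks above is relied on.
from kernels import install_kernel, load_kernel

# install_kernel downloads both the compiled variant for this machine and
# build/torch-universal/* (see allow_patterns in the hunk above).
install_kernel("kernels-community/triton-scaled-mm", revision="main")

# load_kernel then imports from the local cache (local_files_only=True),
# falling back to torch-universal when no compiled variant directory exists.
scaled_mm = load_kernel("kernels-community/triton-scaled-mm")
```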
@@ -55,9 +55,9 @@
   },
   {
     "repo_id": "kernels-community/triton-scaled-mm",
-    "sha": "9baccbeb763fe5f1b8fbdb9c1e5699548c46632c",
+    "sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f",
     "variants": {
-      "torch-noarch": {
+      "torch-universal": {
         "hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52",
         "hash_type": "git_lfs_concat"
       }
@@ -1,3 +1,3 @@
 [tool.kernels.dependencies]
 "kernels-community/activation" = ">=0.0.2"
-"kernels-community/triton-scaled-mm" = ">=0.0.1"
+"kernels-community/triton-scaled-mm" = ">=0.0.2"
@@ -10,7 +10,7 @@ def kernel():


 @pytest.fixture
-def noarch_kernel():
+def universal_kernel():
     return get_kernel("kernels-community/triton-scaled-mm")

@@ -36,14 +36,14 @@ def test_gelu_fast(kernel, device):
     assert torch.allclose(y, expected)


-def test_noarch_kernel(noarch_kernel):
+def test_universal_kernel(universal_kernel):
     torch.manual_seed(0)
     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
     B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda")
     scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda")
     scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda")

-    out = noarch_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
+    out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
     out_check = (A * scale_a) @ (B * scale_b)
     out_check = out_check.to(torch.float16)
