mirror of
				https://github.com/huggingface/kernels.git
				synced 2025-11-04 14:14:31 +08:00 
			
		
		
		
	Compare commits
	
		
			25 Commits
		
	
	
		
			faq-kernel
			...
			v0.11.0
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| a775b72cad | |||
| 10753bdcb9 | |||
| 313e883a04 | |||
| f32e3cd564 | |||
| 75670113f6 | |||
| 5d21b86a5d | |||
| cc754ff93e | |||
| 46a2a1fd6d | |||
| ed048616fe | |||
| b182cd3458 | |||
| ce77658efc | |||
| b96b154e7f | |||
| b24ef9fa6b | |||
| a7101b2cfd | |||
| 6241afa06e | |||
| 34a1932751 | |||
| e39eac09c1 | |||
| b0c431fee4 | |||
| 9a188eadbe | |||
| 457c7c1b8d | |||
| fb8cd99a2c | |||
| dfee307d54 | |||
| 93e5765611 | |||
| bf488208be | |||
| 2a14472e4c | 
							
								
								
									
										16
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										16
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							@ -24,7 +24,7 @@ jobs:
 | 
			
		||||
      max-parallel: 4
 | 
			
		||||
      matrix:
 | 
			
		||||
        python-version: ["3.10", "3.12"]
 | 
			
		||||
        torch-version: ["2.6.0", "2.7.0"]
 | 
			
		||||
        torch-version: ["2.7.0", "2.8.0"]
 | 
			
		||||
 | 
			
		||||
    env:
 | 
			
		||||
      UV_PYTHON_PREFERENCE: only-managed
 | 
			
		||||
@ -51,11 +51,16 @@ jobs:
 | 
			
		||||
        run: uv run mypy src/kernels
 | 
			
		||||
 | 
			
		||||
      - name: Run tests
 | 
			
		||||
        env:
 | 
			
		||||
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
 | 
			
		||||
        run: |
 | 
			
		||||
          uv run pytest tests
 | 
			
		||||
 | 
			
		||||
      - name: Run staging tests
 | 
			
		||||
        env:
 | 
			
		||||
          HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
 | 
			
		||||
        run: |
 | 
			
		||||
          HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
 | 
			
		||||
        if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0'
 | 
			
		||||
 | 
			
		||||
      - name: Check kernel conversion
 | 
			
		||||
        run: |
 | 
			
		||||
          uv pip install wheel
 | 
			
		||||
@ -68,6 +73,11 @@ jobs:
 | 
			
		||||
        run: |
 | 
			
		||||
          uv run kernels generate-readme kernels-community/triton-layer-norm
 | 
			
		||||
 | 
			
		||||
      - name: Check kernel check
 | 
			
		||||
        run: |
 | 
			
		||||
          uv pip install kernel-abi-check
 | 
			
		||||
          kernels check kernels-community/activation
 | 
			
		||||
 | 
			
		||||
      - name: Import check without torch
 | 
			
		||||
        run: |
 | 
			
		||||
          uv pip uninstall torch
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								Makefile
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,8 @@
 | 
			
		||||
.PHONY: style
 | 
			
		||||
 | 
			
		||||
export check_dirs := src examples tests
 | 
			
		||||
 | 
			
		||||
style:
 | 
			
		||||
	black ${check_dirs}
 | 
			
		||||
	isort ${check_dirs}
 | 
			
		||||
	ruff check ${check_dirs} --fix
 | 
			
		||||
@ -51,7 +51,7 @@ activation.gelu_fast(y, x)
 | 
			
		||||
print(y)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
You can [search for kernels](https://huggingface.co/models?other=kernel) on
 | 
			
		||||
You can [search for kernels](https://huggingface.co/models?other=kernels) on
 | 
			
		||||
the Hub.
 | 
			
		||||
 | 
			
		||||
## 📚 Documentation
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,10 @@
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.get_kernel
 | 
			
		||||
 | 
			
		||||
### get_local_kernel
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.get_local_kernel
 | 
			
		||||
 | 
			
		||||
### has_kernel
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.has_kernel
 | 
			
		||||
 | 
			
		||||
@ -39,3 +39,11 @@
 | 
			
		||||
### LayerRepository
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.LayerRepository
 | 
			
		||||
 | 
			
		||||
### LocalLayerRepository
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.LocalLayerRepository
 | 
			
		||||
 | 
			
		||||
### LockedLayerRepository
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.LockedLayerRepository
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,24 @@
 | 
			
		||||
 | 
			
		||||
## Main Functions
 | 
			
		||||
 | 
			
		||||
### kernels check
 | 
			
		||||
 | 
			
		||||
You can use `kernels check` to test compliance of a kernel on the Hub.
 | 
			
		||||
This currently checks that the kernel:
 | 
			
		||||
 | 
			
		||||
- Supports the currently-required Python ABI version.
 | 
			
		||||
- Works on supported operating system versions.
 | 
			
		||||
 | 
			
		||||
For example:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
$ kernels check kernels-community/flash-attn3
 | 
			
		||||
Checking variant: torch28-cxx11-cu128-aarch64-linux
 | 
			
		||||
  🐍 Python ABI 3.9 compatible
 | 
			
		||||
  🐧 manylinux_2_28 compatible
 | 
			
		||||
[...]
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### kernels to-wheel
 | 
			
		||||
 | 
			
		||||
We strongly recommend downloading kernels from the Hub using the `kernels`
 | 
			
		||||
@ -30,7 +48,7 @@ $ kernels to-wheel drbh/img2grey 1.1.2
 | 
			
		||||
### kernels upload
 | 
			
		||||
 | 
			
		||||
Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
 | 
			
		||||
your kernel builds to the Hub.
 | 
			
		||||
your kernel builds to the Hub. To know the supported arguments run: `kernels upload -h`.
 | 
			
		||||
 | 
			
		||||
**Notes**:
 | 
			
		||||
 | 
			
		||||
@ -38,4 +56,3 @@ your kernel builds to the Hub.
 | 
			
		||||
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
 | 
			
		||||
  being uploaded, it will attempt to delete the files existing under it.
 | 
			
		||||
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -39,3 +39,13 @@ The approach of `forward`-replacement is the least invasive, because
 | 
			
		||||
it preserves the original model graph. It is also reversible, since
 | 
			
		||||
even though the `forward` of a layer _instance_ might be replaced,
 | 
			
		||||
the corresponding class still has the original `forward`.
 | 
			
		||||
 | 
			
		||||
## Misc
 | 
			
		||||
 | 
			
		||||
### How can I disable kernel reporting in the user-agent?
 | 
			
		||||
 | 
			
		||||
By default, we collect telemetry when a call to `get_kernel()` is made.
 | 
			
		||||
This only includes the `kernels` version, `torch` version, and the build
 | 
			
		||||
information for the kernel being requested.
 | 
			
		||||
 | 
			
		||||
You can disable this by setting `export DISABLE_TELEMETRY=yes`.
 | 
			
		||||
 | 
			
		||||
@ -16,5 +16,5 @@ packages in that they are made to be:
 | 
			
		||||
  the different PyTorch build configurations (various CUDA versions
 | 
			
		||||
  and C++ ABIs). Furthermore, older C library versions must be supported.
 | 
			
		||||
 | 
			
		||||
You can [search for kernels](https://huggingface.co/models?other=kernel) on
 | 
			
		||||
You can [search for kernels](https://huggingface.co/models?other=kernels) on
 | 
			
		||||
the Hub.
 | 
			
		||||
 | 
			
		||||
@ -46,6 +46,16 @@ have dynamic library dependencies outside:
 | 
			
		||||
- Torch;
 | 
			
		||||
- CUDA/ROCm libraries installed as dependencies of Torch.
 | 
			
		||||
 | 
			
		||||
## Compatibility with torch.compile
 | 
			
		||||
 | 
			
		||||
The Kernel Hub also encourages to write the kernels in a `torch.compile`
 | 
			
		||||
compliant way. This helps to ensure that the kernels are compatible with
 | 
			
		||||
`torch.compile` without introducing any graph breaks and triggering 
 | 
			
		||||
recompilation which can limit the benefits of compilation.
 | 
			
		||||
 | 
			
		||||
[Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example which checks for graph breaks and 
 | 
			
		||||
recompilation triggers during `torch.compile`.
 | 
			
		||||
 | 
			
		||||
### Linux
 | 
			
		||||
 | 
			
		||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
 | 
			
		||||
 | 
			
		||||
@ -20,11 +20,11 @@ activation.gelu_fast(y, x)
 | 
			
		||||
print("Kernel successfully executed")
 | 
			
		||||
 | 
			
		||||
# Check results
 | 
			
		||||
expected = torch.tensor([
 | 
			
		||||
    [0.8408, 1.9551, 2.9961],
 | 
			
		||||
    [4.0000, 5.0000, 6.0000],
 | 
			
		||||
    [7.0000, 8.0000, 9.0000]
 | 
			
		||||
], device='cuda:0', dtype=torch.float16)
 | 
			
		||||
expected = torch.tensor(
 | 
			
		||||
    [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
 | 
			
		||||
    device="cuda:0",
 | 
			
		||||
    dtype=torch.float16,
 | 
			
		||||
)
 | 
			
		||||
assert torch.allclose(y, expected)
 | 
			
		||||
 | 
			
		||||
print("Calculated values are exact")
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,7 @@
 | 
			
		||||
      in
 | 
			
		||||
      {
 | 
			
		||||
        formatter = pkgs.nixfmt-tree;
 | 
			
		||||
        packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {};
 | 
			
		||||
        devShells = with pkgs; rec {
 | 
			
		||||
          default = mkShell {
 | 
			
		||||
            nativeBuildInputs = [
 | 
			
		||||
@ -40,6 +41,7 @@
 | 
			
		||||
              ++ (with python3.pkgs; [
 | 
			
		||||
                docutils
 | 
			
		||||
                huggingface-hub
 | 
			
		||||
                (callPackage ./nix/kernel-abi-check.nix {})
 | 
			
		||||
                mktestdocs
 | 
			
		||||
                pytest
 | 
			
		||||
                pytest-benchmark
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										27
									
								
								nix/kernel-abi-check.nix
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								nix/kernel-abi-check.nix
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,27 @@
 | 
			
		||||
{
 | 
			
		||||
  buildPythonPackage,
 | 
			
		||||
  fetchPypi,
 | 
			
		||||
  rustPlatform,
 | 
			
		||||
}:
 | 
			
		||||
 | 
			
		||||
buildPythonPackage rec {
 | 
			
		||||
  pname = "kernel-abi-check";
 | 
			
		||||
  version = "0.6.2";
 | 
			
		||||
 | 
			
		||||
  src = fetchPypi {
 | 
			
		||||
    inherit version;
 | 
			
		||||
    pname = "kernel_abi_check";
 | 
			
		||||
    hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys=";
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  cargoDeps = rustPlatform.fetchCargoVendor {
 | 
			
		||||
    inherit pname version src sourceRoot;
 | 
			
		||||
    hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A=";
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  sourceRoot = "kernel_abi_check-${version}/bindings/python";
 | 
			
		||||
 | 
			
		||||
  pyproject = true;
 | 
			
		||||
 | 
			
		||||
  nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ];
 | 
			
		||||
}
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
[project]
 | 
			
		||||
name = "kernels"
 | 
			
		||||
version = "0.10.1.dev0"
 | 
			
		||||
version = "0.11.0"
 | 
			
		||||
description = "Download compute kernels"
 | 
			
		||||
authors = [
 | 
			
		||||
  { name = "OlivierDehaene", email = "olivier@huggingface.co" },
 | 
			
		||||
@ -12,7 +12,7 @@ license = { text = "Apache-2.0" }
 | 
			
		||||
readme = "README.md"
 | 
			
		||||
requires-python = ">= 3.9"
 | 
			
		||||
dependencies = [
 | 
			
		||||
  "huggingface_hub>=0.26.0,<1.0",
 | 
			
		||||
  "huggingface_hub>=0.26.0,<2.0",
 | 
			
		||||
  "packaging>=20.0",
 | 
			
		||||
  "pyyaml>=6",
 | 
			
		||||
  "tomli>=2.0; python_version<'3.11'",
 | 
			
		||||
@ -34,6 +34,7 @@ dev = [
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[project.optional-dependencies]
 | 
			
		||||
abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
 | 
			
		||||
torch = ["torch"]
 | 
			
		||||
docs = [
 | 
			
		||||
  "hf-doc-builder",
 | 
			
		||||
@ -45,6 +46,9 @@ kernels = "kernels.cli:main"
 | 
			
		||||
[project.entry-points."egg_info.writers"]
 | 
			
		||||
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
 | 
			
		||||
 | 
			
		||||
[tool.isort]
 | 
			
		||||
profile = "black"
 | 
			
		||||
line_length = 119
 | 
			
		||||
 | 
			
		||||
[tool.ruff]
 | 
			
		||||
exclude = [
 | 
			
		||||
@ -71,4 +75,4 @@ line-length = 119
 | 
			
		||||
# Ignored rules:
 | 
			
		||||
# "E501" -> line length violation
 | 
			
		||||
lint.ignore = ["E501"]
 | 
			
		||||
lint.select = ["E", "F", "I", "W"]
 | 
			
		||||
lint.select = ["E", "F", "W"]
 | 
			
		||||
 | 
			
		||||
@ -4,4 +4,6 @@ markers =
 | 
			
		||||
    rocm_only: marks tests that should only run on hosts with ROCm GPUs
 | 
			
		||||
    darwin_only: marks tests that should only run on macOS
 | 
			
		||||
    xpu_only: marks tests that should only run on hosts with Intel XPUs
 | 
			
		||||
    npu_only: marks tests that should only run on Ascend NPUs
 | 
			
		||||
    token: enable tests that require a write token
 | 
			
		||||
    is_staging_test: Marks tests that should only run on a staging environment
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										142
									
								
								src/kernels/check.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								src/kernels/check.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,142 @@
 | 
			
		||||
import sys
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
from kernel_abi_check import (
 | 
			
		||||
    BinaryFormat,
 | 
			
		||||
    IncompatibleAbi3Symbol,
 | 
			
		||||
    IncompatibleMacOSVersion,
 | 
			
		||||
    IncompatibleManylinuxSymbol,
 | 
			
		||||
    MissingMacOSVersion,
 | 
			
		||||
    NonAbi3Symbol,
 | 
			
		||||
    ObjectFile,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from kernels.utils import CACHE_DIR
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_kernel(
 | 
			
		||||
    *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
 | 
			
		||||
):
 | 
			
		||||
    variants_path = (
 | 
			
		||||
        Path(
 | 
			
		||||
            snapshot_download(
 | 
			
		||||
                repo_id,
 | 
			
		||||
                allow_patterns=["build/*"],
 | 
			
		||||
                cache_dir=CACHE_DIR,
 | 
			
		||||
                revision=revision,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        / "build"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    has_issues = False
 | 
			
		||||
    for variant_path in variants_path.iterdir():
 | 
			
		||||
        if not variant_path.is_dir():
 | 
			
		||||
            print(
 | 
			
		||||
                f"⛔ `build/` must only contain directories, found: {variant_path.name}",
 | 
			
		||||
                file=sys.stderr,
 | 
			
		||||
            )
 | 
			
		||||
            has_issues = True
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        print(f"Checking variant: {variant_path.name}", file=sys.stderr)
 | 
			
		||||
 | 
			
		||||
        indent = 2
 | 
			
		||||
 | 
			
		||||
        for dylib_path in variant_path.rglob("*.so"):
 | 
			
		||||
            print_with_indent(
 | 
			
		||||
                indent,
 | 
			
		||||
                f"Dynamic library {dylib_path.relative_to(variant_path)}:",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            o = ObjectFile(dylib_path)
 | 
			
		||||
            has_issues |= check_abi3(o, python_abi, indent + 2)
 | 
			
		||||
 | 
			
		||||
            # TODO: also check operating system
 | 
			
		||||
            if o.format() == BinaryFormat.ELF:
 | 
			
		||||
                has_issues |= check_manylinux(o, manylinux, indent + 2)
 | 
			
		||||
            elif o.format() == BinaryFormat.MACH_O:
 | 
			
		||||
                has_issues |= check_macos(o, macos, indent + 2)
 | 
			
		||||
 | 
			
		||||
    if has_issues:
 | 
			
		||||
        sys.exit(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool:
 | 
			
		||||
    has_issues = False
 | 
			
		||||
    violations = object_file.check_python_abi(python_abi)
 | 
			
		||||
    if violations != []:
 | 
			
		||||
        has_issues = True
 | 
			
		||||
        print_with_indent(
 | 
			
		||||
            indent,
 | 
			
		||||
            f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:",
 | 
			
		||||
        )
 | 
			
		||||
        for violation in violations:
 | 
			
		||||
            if isinstance(violation, IncompatibleAbi3Symbol):
 | 
			
		||||
                print_with_indent(
 | 
			
		||||
                    indent + 3,
 | 
			
		||||
                    f"{violation.name}: {violation.version_added}",
 | 
			
		||||
                )
 | 
			
		||||
            elif isinstance(violation, NonAbi3Symbol):
 | 
			
		||||
                print_with_indent(
 | 
			
		||||
                    indent + 3,
 | 
			
		||||
                    f"{violation.name}",
 | 
			
		||||
                )
 | 
			
		||||
    else:
 | 
			
		||||
        print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible")
 | 
			
		||||
 | 
			
		||||
    return has_issues
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool:
 | 
			
		||||
    has_issues = False
 | 
			
		||||
    violations = object_file.check_macos(macos)
 | 
			
		||||
    if violations != []:
 | 
			
		||||
        has_issues = True
 | 
			
		||||
        print_with_indent(
 | 
			
		||||
            indent,
 | 
			
		||||
            f"⛔ Found incompatibility with macOS {macos}:",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for violation in violations:
 | 
			
		||||
            if isinstance(violation, MissingMacOSVersion):
 | 
			
		||||
                print_with_indent(
 | 
			
		||||
                    indent + 3,
 | 
			
		||||
                    "shared library does not contain macOS version",
 | 
			
		||||
                )
 | 
			
		||||
            elif isinstance(violation, IncompatibleMacOSVersion):
 | 
			
		||||
                print_with_indent(
 | 
			
		||||
                    indent + 3,
 | 
			
		||||
                    f"shared library requires macOS {violation.version}",
 | 
			
		||||
                )
 | 
			
		||||
    else:
 | 
			
		||||
        print_with_indent(indent, f"🍏 compatible with macOS {macos}")
 | 
			
		||||
 | 
			
		||||
    return has_issues
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool:
 | 
			
		||||
    has_issues = False
 | 
			
		||||
    violations = object_file.check_manylinux(manylinux)
 | 
			
		||||
    if violations != []:
 | 
			
		||||
        has_issues = True
 | 
			
		||||
        print_with_indent(
 | 
			
		||||
            indent,
 | 
			
		||||
            f"⛔ Found symbols that are incompatible with {manylinux}:",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for violation in violations:
 | 
			
		||||
            if isinstance(violation, IncompatibleManylinuxSymbol):
 | 
			
		||||
                print_with_indent(
 | 
			
		||||
                    indent + 3,
 | 
			
		||||
                    f"{violation.name}_{violation.dep}: {violation.version}",
 | 
			
		||||
                )
 | 
			
		||||
    else:
 | 
			
		||||
        print_with_indent(indent, f"🐧 {manylinux} compatible")
 | 
			
		||||
 | 
			
		||||
    return has_issues
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_with_indent(indent: int, message: str):
 | 
			
		||||
    print(f"{' ' * indent}{message}", file=sys.stderr)
 | 
			
		||||
@ -1,10 +1,11 @@
 | 
			
		||||
import argparse
 | 
			
		||||
import dataclasses
 | 
			
		||||
import json
 | 
			
		||||
import re
 | 
			
		||||
import sys
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import create_repo, upload_folder
 | 
			
		||||
from huggingface_hub import create_repo, upload_folder, create_branch
 | 
			
		||||
 | 
			
		||||
from kernels.compat import tomllib
 | 
			
		||||
from kernels.lockfile import KernelLock, get_kernel_locks
 | 
			
		||||
@ -13,6 +14,8 @@ from kernels.utils import install_kernel, install_kernel_all_variants
 | 
			
		||||
from .doc import generate_readme_for_kernel
 | 
			
		||||
from .wheel import build_variant_to_wheel
 | 
			
		||||
 | 
			
		||||
BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-universal)")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
@ -20,6 +23,31 @@ def main():
 | 
			
		||||
    )
 | 
			
		||||
    subparsers = parser.add_subparsers(required=True)
 | 
			
		||||
 | 
			
		||||
    check_parser = subparsers.add_parser("check", help="Check a kernel for compliance")
 | 
			
		||||
    check_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
 | 
			
		||||
    check_parser.add_argument(
 | 
			
		||||
        "--revision",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="main",
 | 
			
		||||
        help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
 | 
			
		||||
    )
 | 
			
		||||
    check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0")
 | 
			
		||||
    check_parser.add_argument(
 | 
			
		||||
        "--manylinux", type=str, help="Manylinux version", default="manylinux_2_28"
 | 
			
		||||
    )
 | 
			
		||||
    check_parser.add_argument(
 | 
			
		||||
        "--python-abi", type=str, help="Python ABI version", default="3.9"
 | 
			
		||||
    )
 | 
			
		||||
    check_parser.set_defaults(
 | 
			
		||||
        func=lambda args: check_kernel(
 | 
			
		||||
            macos=args.macos,
 | 
			
		||||
            manylinux=args.manylinux,
 | 
			
		||||
            python_abi=args.python_abi,
 | 
			
		||||
            repo_id=args.repo_id,
 | 
			
		||||
            revision=args.revision,
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    download_parser = subparsers.add_parser("download", help="Download locked kernels")
 | 
			
		||||
    download_parser.add_argument(
 | 
			
		||||
        "project_dir",
 | 
			
		||||
@ -40,10 +68,15 @@ def main():
 | 
			
		||||
        help="Directory of the kernel build",
 | 
			
		||||
    )
 | 
			
		||||
    upload_parser.add_argument(
 | 
			
		||||
        "--repo_id",
 | 
			
		||||
        "--repo-id",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Repository ID to use to upload to the Hugging Face Hub",
 | 
			
		||||
    )
 | 
			
		||||
    upload_parser.add_argument(
 | 
			
		||||
        "--branch",
 | 
			
		||||
        type=None,
 | 
			
		||||
        help="If set, the upload will be made to a particular branch of the provided `repo-id`.",
 | 
			
		||||
    )
 | 
			
		||||
    upload_parser.add_argument(
 | 
			
		||||
        "--private",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
@ -174,17 +207,31 @@ def lock_kernels(args):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upload_kernels(args):
 | 
			
		||||
    # Resolve `kernel_dir` to be uploaded.
 | 
			
		||||
    kernel_dir = Path(args.kernel_dir).resolve()
 | 
			
		||||
    build_dir = kernel_dir / "build"
 | 
			
		||||
    if not kernel_dir.is_dir():
 | 
			
		||||
        raise ValueError(f"{kernel_dir} is not a directory")
 | 
			
		||||
    if not build_dir.is_dir():
 | 
			
		||||
        raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
 | 
			
		||||
 | 
			
		||||
    build_dir = None
 | 
			
		||||
    for candidate in [kernel_dir / "build", kernel_dir]:
 | 
			
		||||
        variants = [
 | 
			
		||||
            variant_path
 | 
			
		||||
            for variant_path in candidate.glob("torch*")
 | 
			
		||||
            if BUILD_VARIANT_REGEX.match(variant_path.name) is not None
 | 
			
		||||
        ]
 | 
			
		||||
        if variants:
 | 
			
		||||
            build_dir = candidate
 | 
			
		||||
            break
 | 
			
		||||
    if build_dir is None:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Couldn't find any build variants in: {kernel_dir.absolute()} or {(kernel_dir / 'build').absolute()}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    repo_id = create_repo(
 | 
			
		||||
        repo_id=args.repo_id, private=args.private, exist_ok=True
 | 
			
		||||
    ).repo_id
 | 
			
		||||
 | 
			
		||||
    if args.branch is not None:
 | 
			
		||||
        create_branch(repo_id=repo_id, branch=args.branch, exist_ok=True)
 | 
			
		||||
 | 
			
		||||
    delete_patterns: set[str] = set()
 | 
			
		||||
    for build_variant in build_dir.iterdir():
 | 
			
		||||
        if build_variant.is_dir():
 | 
			
		||||
@ -193,6 +240,7 @@ def upload_kernels(args):
 | 
			
		||||
    upload_folder(
 | 
			
		||||
        repo_id=repo_id,
 | 
			
		||||
        folder_path=build_dir,
 | 
			
		||||
        revision=args.branch,
 | 
			
		||||
        path_in_repo="build",
 | 
			
		||||
        delete_patterns=list(delete_patterns),
 | 
			
		||||
        commit_message="Build uploaded using `kernels`.",
 | 
			
		||||
@ -205,3 +253,24 @@ class _JSONEncoder(json.JSONEncoder):
 | 
			
		||||
        if dataclasses.is_dataclass(o):
 | 
			
		||||
            return dataclasses.asdict(o)
 | 
			
		||||
        return super().default(o)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_kernel(
 | 
			
		||||
    *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
 | 
			
		||||
):
 | 
			
		||||
    try:
 | 
			
		||||
        import kernels.check
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        print(
 | 
			
		||||
            "`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check",
 | 
			
		||||
            file=sys.stderr,
 | 
			
		||||
        )
 | 
			
		||||
        sys.exit(1)
 | 
			
		||||
 | 
			
		||||
    kernels.check.check_kernel(
 | 
			
		||||
        macos=macos,
 | 
			
		||||
        manylinux=manylinux,
 | 
			
		||||
        python_abi=python_abi,
 | 
			
		||||
        repo_id=repo_id,
 | 
			
		||||
        revision=revision,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -111,10 +111,10 @@ def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
 | 
			
		||||
def generate_metadata(module: ModuleType) -> None:
 | 
			
		||||
    metadata = getattr(module, "__kernel_metadata__", {})
 | 
			
		||||
    if "tags" not in metadata:
 | 
			
		||||
        metadata["tags"] = ["kernel"]
 | 
			
		||||
        metadata["tags"] = ["kernels"]
 | 
			
		||||
    else:
 | 
			
		||||
        if "kernel" not in metadata["tags"]:
 | 
			
		||||
            metadata["tags"].append("kernel")
 | 
			
		||||
        if "kernels" not in metadata["tags"]:
 | 
			
		||||
            metadata["tags"].append("kernels")
 | 
			
		||||
 | 
			
		||||
    print("---")
 | 
			
		||||
    print(yaml.dump(metadata), end="")
 | 
			
		||||
 | 
			
		||||
@ -87,7 +87,7 @@ class Device:
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        type (`str`):
 | 
			
		||||
            The device type (e.g., "cuda", "mps", "rocm", "xpu").
 | 
			
		||||
            The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu").
 | 
			
		||||
        properties ([`CUDAProperties`], *optional*):
 | 
			
		||||
            Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.
 | 
			
		||||
 | 
			
		||||
@ -109,6 +109,9 @@ class Device:
 | 
			
		||||
 | 
			
		||||
        # XPU device (e.g., Intel(R) Data Center GPU Max 1550)
 | 
			
		||||
        xpu_device = Device(type="xpu")
 | 
			
		||||
 | 
			
		||||
        # NPU device (Huawei Ascend)
 | 
			
		||||
        npu_device = Device(type="npu")
 | 
			
		||||
        ```
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -130,6 +133,8 @@ class Device:
 | 
			
		||||
            return _MPSRepos()
 | 
			
		||||
        elif self.type == "xpu":
 | 
			
		||||
            return _XPURepos()
 | 
			
		||||
        elif self.type == "npu":
 | 
			
		||||
            return _NPURepos()
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f"Unknown device type: {self.type}")
 | 
			
		||||
 | 
			
		||||
@ -472,6 +477,26 @@ class _XPURepos(_DeviceRepos):
 | 
			
		||||
        self._repos = repos
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _NPURepos(_DeviceRepos):
 | 
			
		||||
    _repos: Dict[Mode, LayerRepositoryProtocol]
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._repos = {}
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def repos(
 | 
			
		||||
        self,
 | 
			
		||||
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
 | 
			
		||||
        return self._repos
 | 
			
		||||
 | 
			
		||||
    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
 | 
			
		||||
        if device.type != "npu":
 | 
			
		||||
            raise ValueError(f"Device type must be 'npu', got {device.type}")
 | 
			
		||||
 | 
			
		||||
        self._repos = repos
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _MPSRepos(_DeviceRepos):
 | 
			
		||||
    _repos: Dict[Mode, LayerRepositoryProtocol]
 | 
			
		||||
 | 
			
		||||
@ -556,7 +581,7 @@ class _ROCMRepos(_DeviceRepos):
 | 
			
		||||
 | 
			
		||||
def _validate_device_type(device_type: str) -> None:
 | 
			
		||||
    """Validate that the device type is supported."""
 | 
			
		||||
    supported_devices = {"cuda", "rocm", "mps", "xpu"}
 | 
			
		||||
    supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"}
 | 
			
		||||
    if device_type not in supported_devices:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
 | 
			
		||||
@ -814,7 +839,7 @@ def kernelize(
 | 
			
		||||
            `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
 | 
			
		||||
            `torch.compile`.
 | 
			
		||||
        device (`Union[str, torch.device]`, *optional*):
 | 
			
		||||
            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
 | 
			
		||||
            The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu".
 | 
			
		||||
            The device type will be inferred from the model parameters when not provided.
 | 
			
		||||
        use_fallback (`bool`, *optional*, defaults to `True`):
 | 
			
		||||
            Whether to use the original forward method of modules when no compatible kernel could be found.
 | 
			
		||||
@ -838,7 +863,7 @@ def kernelize(
 | 
			
		||||
                return F.silu(x[..., :d]) * x[..., d:]
 | 
			
		||||
 | 
			
		||||
        mapping = {
 | 
			
		||||
            "LayerNorm": {
 | 
			
		||||
            "SiluAndMul": {
 | 
			
		||||
                "cuda": LayerRepository(
 | 
			
		||||
                    repo_id="kernels-community/activation",
 | 
			
		||||
                    layer_name="SiluAndMul",
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ import sys
 | 
			
		||||
from importlib.metadata import Distribution
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from types import ModuleType
 | 
			
		||||
from typing import Dict, List, Optional, Tuple
 | 
			
		||||
from typing import Dict, List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import file_exists, snapshot_download
 | 
			
		||||
from packaging.version import parse
 | 
			
		||||
@ -19,6 +19,8 @@ from packaging.version import parse
 | 
			
		||||
from kernels._versions import select_revision_or_version
 | 
			
		||||
from kernels.lockfile import KernelLock, VariantLock
 | 
			
		||||
 | 
			
		||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_cache_dir() -> Optional[str]:
 | 
			
		||||
    """Returns the kernels cache directory."""
 | 
			
		||||
@ -35,6 +37,14 @@ def _get_cache_dir() -> Optional[str]:
 | 
			
		||||
CACHE_DIR: Optional[str] = _get_cache_dir()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_privateuse_backend_name() -> Optional[str]:
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
    if hasattr(torch._C, "_get_privateuse1_backend_name"):
 | 
			
		||||
        return torch._C._get_privateuse1_backend_name()
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_variant() -> str:
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
@ -46,13 +56,16 @@ def build_variant() -> str:
 | 
			
		||||
        compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
 | 
			
		||||
    elif torch.backends.mps.is_available():
 | 
			
		||||
        compute_framework = "metal"
 | 
			
		||||
    elif torch.version.xpu is not None:
 | 
			
		||||
    elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
 | 
			
		||||
        version = torch.version.xpu
 | 
			
		||||
        compute_framework = f"xpu{version[0:4]}{version[5:6]}"
 | 
			
		||||
    elif _get_privateuse_backend_name() == "npu":
 | 
			
		||||
        from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found]
 | 
			
		||||
 | 
			
		||||
        cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
 | 
			
		||||
        compute_framework = f"cann{cann_major}{cann_minor}"
 | 
			
		||||
    else:
 | 
			
		||||
        raise AssertionError(
 | 
			
		||||
            "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
 | 
			
		||||
        )
 | 
			
		||||
        compute_framework = "cpu"
 | 
			
		||||
 | 
			
		||||
    torch_version = parse(torch.__version__)
 | 
			
		||||
    cpu = platform.machine()
 | 
			
		||||
@ -61,9 +74,11 @@ def build_variant() -> str:
 | 
			
		||||
    if os == "darwin":
 | 
			
		||||
        cpu = "aarch64" if cpu == "arm64" else cpu
 | 
			
		||||
        return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
 | 
			
		||||
    elif os == "windows":
 | 
			
		||||
        cpu = "x86_64" if cpu == "AMD64" else cpu
 | 
			
		||||
        return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
 | 
			
		||||
 | 
			
		||||
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
 | 
			
		||||
 | 
			
		||||
    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -95,6 +110,7 @@ def install_kernel(
 | 
			
		||||
    revision: str,
 | 
			
		||||
    local_files_only: bool = False,
 | 
			
		||||
    variant_locks: Optional[Dict[str, VariantLock]] = None,
 | 
			
		||||
    user_agent: Optional[Union[str, dict]] = None,
 | 
			
		||||
) -> Tuple[str, Path]:
 | 
			
		||||
    """
 | 
			
		||||
    Download a kernel for the current environment to the cache.
 | 
			
		||||
@ -110,6 +126,8 @@ def install_kernel(
 | 
			
		||||
            Whether to only use local files and not download from the Hub.
 | 
			
		||||
        variant_locks (`Dict[str, VariantLock]`, *optional*):
 | 
			
		||||
            Optional dictionary of variant locks for validation.
 | 
			
		||||
        user_agent (`Union[str, dict]`, *optional*):
 | 
			
		||||
            The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
 | 
			
		||||
@ -117,6 +135,7 @@ def install_kernel(
 | 
			
		||||
    package_name = package_name_from_repo_id(repo_id)
 | 
			
		||||
    variant = build_variant()
 | 
			
		||||
    universal_variant = universal_build_variant()
 | 
			
		||||
    user_agent = _get_user_agent(user_agent=user_agent)
 | 
			
		||||
    repo_path = Path(
 | 
			
		||||
        snapshot_download(
 | 
			
		||||
            repo_id,
 | 
			
		||||
@ -124,6 +143,7 @@ def install_kernel(
 | 
			
		||||
            cache_dir=CACHE_DIR,
 | 
			
		||||
            revision=revision,
 | 
			
		||||
            local_files_only=local_files_only,
 | 
			
		||||
            user_agent=user_agent,
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -200,7 +220,10 @@ def install_kernel_all_variants(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_kernel(
 | 
			
		||||
    repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
 | 
			
		||||
    repo_id: str,
 | 
			
		||||
    revision: Optional[str] = None,
 | 
			
		||||
    version: Optional[str] = None,
 | 
			
		||||
    user_agent: Optional[Union[str, dict]] = None,
 | 
			
		||||
) -> ModuleType:
 | 
			
		||||
    """
 | 
			
		||||
    Load a kernel from the kernel hub.
 | 
			
		||||
@ -216,6 +239,8 @@ def get_kernel(
 | 
			
		||||
        version (`str`, *optional*):
 | 
			
		||||
            The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
 | 
			
		||||
            Cannot be used together with `revision`.
 | 
			
		||||
        user_agent (`Union[str, dict]`, *optional*):
 | 
			
		||||
            The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `ModuleType`: The imported kernel module.
 | 
			
		||||
@ -232,7 +257,9 @@ def get_kernel(
 | 
			
		||||
        ```
 | 
			
		||||
    """
 | 
			
		||||
    revision = select_revision_or_version(repo_id, revision, version)
 | 
			
		||||
    package_name, package_path = install_kernel(repo_id, revision=revision)
 | 
			
		||||
    package_name, package_path = install_kernel(
 | 
			
		||||
        repo_id, revision=revision, user_agent=user_agent
 | 
			
		||||
    )
 | 
			
		||||
    return import_from_path(package_name, package_path / package_name / "__init__.py")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -488,3 +515,29 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
 | 
			
		||||
 | 
			
		||||
def package_name_from_repo_id(repo_id: str) -> str:
 | 
			
		||||
    return repo_id.split("/")[-1].replace("-", "_")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_user_agent(
 | 
			
		||||
    user_agent: Optional[Union[dict, str]] = None,
 | 
			
		||||
) -> Union[None, dict, str]:
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
    from . import __version__
 | 
			
		||||
 | 
			
		||||
    if os.getenv("DISABLE_TELEMETRY", "false").upper() in ENV_VARS_TRUE_VALUES:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    if user_agent is None:
 | 
			
		||||
        user_agent = {}
 | 
			
		||||
    if isinstance(user_agent, dict):
 | 
			
		||||
        user_agent.update(
 | 
			
		||||
            {
 | 
			
		||||
                "kernels": __version__,
 | 
			
		||||
                "torch": torch.__version__,
 | 
			
		||||
                "build_variant": build_variant(),
 | 
			
		||||
                "file_type": "kernel",
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
    elif isinstance(user_agent, str):
 | 
			
		||||
        user_agent += f"; kernels/{__version__}; torch/{torch.__version__}; build_variant/{build_variant()}; file_type/kernel"
 | 
			
		||||
    return user_agent
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,8 @@ import sys
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from kernels.utils import _get_privateuse_backend_name
 | 
			
		||||
 | 
			
		||||
has_cuda = (
 | 
			
		||||
    hasattr(torch.version, "cuda")
 | 
			
		||||
    and torch.version.cuda is not None
 | 
			
		||||
@ -18,6 +20,7 @@ has_xpu = (
 | 
			
		||||
    and torch.version.xpu is not None
 | 
			
		||||
    and torch.xpu.device_count() > 0
 | 
			
		||||
)
 | 
			
		||||
has_npu = _get_privateuse_backend_name() == "npu"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pytest_addoption(parser):
 | 
			
		||||
@ -37,5 +40,7 @@ def pytest_runtest_setup(item):
 | 
			
		||||
        pytest.skip("skipping macOS-only test on non-macOS platform")
 | 
			
		||||
    if "xpu_only" in item.keywords and not has_xpu:
 | 
			
		||||
        pytest.skip("skipping XPU-only test on host without XPU")
 | 
			
		||||
    if "npu_only" in item.keywords and not has_npu:
 | 
			
		||||
        pytest.skip("skipping NPU-only test on host without NPU")
 | 
			
		||||
    if "token" in item.keywords and not item.config.getoption("--token"):
 | 
			
		||||
        pytest.skip("need --token option to run this test")
 | 
			
		||||
 | 
			
		||||
@ -1,82 +1,70 @@
 | 
			
		||||
[
 | 
			
		||||
  {
 | 
			
		||||
    "repo_id": "kernels-community/activation",
 | 
			
		||||
    "sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
 | 
			
		||||
    "sha": "83046852be158d525114f68513cd79fd88911b37",
 | 
			
		||||
    "variants": {
 | 
			
		||||
      "torch25-cxx11-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx11-cu121-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx11-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx98-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx98-cu121-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx98-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
 | 
			
		||||
        "hash": "sha256-e34965c814c4c092fcb634ebadefe82ea9a05b98343f8ebdefa7305dcc05359e",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
 | 
			
		||||
        "hash": "sha256-5f92b35922b37224a416398a39a29b7e5f1aca1df17d5c69f1b9e9cdb7033561",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu128-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
 | 
			
		||||
        "hash": "sha256-125967cb23bacd2cec443799f184ac08247dfff33f5027e54ee16d3779ca5986",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu128-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
 | 
			
		||||
        "hash": "sha256-496a84c99d7035a1b6f0ea1c026b751c3a2677956f4c1be546d3cc1505a5fdbb",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch28-cxx11-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-f0775a30ffa290c90aba3a41037e3ca91edb15b4a9367561fafd5f25455e117a",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch28-cxx11-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-081995e6230f306bdf6111186618794f2411cf0ffd9b4800330df60b4ebe1927",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch28-cxx11-cu128-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-b937fef62a0c1cd71ab98490b651c473577af209b9a3e2a6b452350283d8812c",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch28-cxx11-cu128-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-a3915686cc58641a3361ece63ab77b33e9d30315dea12547e4bda008d8810a01",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch28-cxx11-cu129-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-a24dca8e998f88be42491921c9df89d88a6112ca630acd2efc2dd34a64b91fcb",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch28-cxx11-cu129-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-df6c70a70f425db2f68b86561c6f93c5675c1d5e5d058766d88ab17472229907",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch29-cxx11-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-c120011c201072b4cfd70c2ba2d45c2f05337feaf604ddec3c6c4987def33ab3",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch29-cxx11-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-765a7f3279009979be4001a23c5c70e5e6ab9553098d67886731a5275a6d4b32",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch29-cxx11-cu128-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-266d057a9cd82b872a0e02f09ac5e2660fcffcf9a7b7fa1fa8ff33dc19c0f5c2",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch29-cxx11-cu128-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-6850e594ba4588f289b5904eb88eda5a41870ee20a3bf1586f3268307caf4b53",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch29-cxx11-cu130-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-23741b935462b53bdf868f8d1c9c8cff5f02f71ea3b0550df41dc8b030b0b474",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch29-cxx11-cu130-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-b884ae792dc1eada071f31645add0c2c76d479864f25aebcdd8318b675aaaf29",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -35,6 +35,7 @@ def test_load_locked():
 | 
			
		||||
    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_layer_locked():
 | 
			
		||||
    project_dir = Path(__file__).parent / "layer_locking"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -7,11 +7,12 @@ from pathlib import Path
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from huggingface_hub import model_info
 | 
			
		||||
from huggingface_hub import delete_repo, model_info, list_repo_refs
 | 
			
		||||
 | 
			
		||||
from kernels.cli import upload_kernels
 | 
			
		||||
 | 
			
		||||
REPO_ID = "kernels-test/kernels-upload-test"
 | 
			
		||||
REPO_ID = "valid_org/kernels-upload-test"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
PY_CONTENT = """\
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
@ -29,6 +30,7 @@ class UploadArgs:
 | 
			
		||||
    kernel_dir: None
 | 
			
		||||
    repo_id: None
 | 
			
		||||
    private: False
 | 
			
		||||
    branch: None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def next_filename(path: Path) -> Path:
 | 
			
		||||
@ -68,7 +70,38 @@ def get_filenames_from_a_repo(repo_id: str) -> List[str]:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.token
 | 
			
		||||
@pytest.mark.is_staging_test
 | 
			
		||||
@pytest.mark.parametrize("branch", (None, "foo"))
 | 
			
		||||
def test_kernel_upload_works_as_expected(branch):
 | 
			
		||||
    with tempfile.TemporaryDirectory() as tmpdir:
 | 
			
		||||
        path = f"{tmpdir}/build/torch-universal/upload_test"
 | 
			
		||||
        build_dir = Path(path)
 | 
			
		||||
        build_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
        script_path = build_dir / "foo.py"
 | 
			
		||||
        script_path.write_text(PY_CONTENT)
 | 
			
		||||
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False, branch))
 | 
			
		||||
 | 
			
		||||
    repo_filenames = get_filenames_from_a_repo(REPO_ID)
 | 
			
		||||
    assert any(str(script_path.name) for f in repo_filenames)
 | 
			
		||||
 | 
			
		||||
    if branch is not None:
 | 
			
		||||
        refs = list_repo_refs(repo_id=REPO_ID)
 | 
			
		||||
        assert any(ref_branch.name == branch for ref_branch in refs.branches)
 | 
			
		||||
 | 
			
		||||
    delete_repo(repo_id=REPO_ID)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.token
 | 
			
		||||
@pytest.mark.is_staging_test
 | 
			
		||||
def test_kernel_upload_deletes_as_expected():
 | 
			
		||||
    with tempfile.TemporaryDirectory() as tmpdir:
 | 
			
		||||
        path = f"{tmpdir}/build/torch-universal/upload_test"
 | 
			
		||||
        build_dir = Path(path)
 | 
			
		||||
        build_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
        script_path = build_dir / "foo_2025.py"
 | 
			
		||||
        script_path.write_text(PY_CONTENT)
 | 
			
		||||
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None))
 | 
			
		||||
 | 
			
		||||
    repo_filenames = get_filenames_from_a_repo(REPO_ID)
 | 
			
		||||
    filename_to_change = get_filename_to_change(repo_filenames)
 | 
			
		||||
 | 
			
		||||
@ -79,10 +112,11 @@ def test_kernel_upload_deletes_as_expected():
 | 
			
		||||
        changed_filename = next_filename(Path(filename_to_change))
 | 
			
		||||
        script_path = build_dir / changed_filename
 | 
			
		||||
        script_path.write_text(PY_CONTENT)
 | 
			
		||||
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
 | 
			
		||||
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None))
 | 
			
		||||
 | 
			
		||||
    repo_filenames = get_filenames_from_a_repo(REPO_ID)
 | 
			
		||||
    assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
 | 
			
		||||
    assert not any(
 | 
			
		||||
        str(filename_to_change) in k for k in repo_filenames
 | 
			
		||||
    ), f"{repo_filenames=}"
 | 
			
		||||
    delete_repo(repo_id=REPO_ID)
 | 
			
		||||
 | 
			
		||||
@ -21,14 +21,21 @@ from kernels.layer import (
 | 
			
		||||
    _KERNEL_MAPPING,
 | 
			
		||||
    _validate_layer,
 | 
			
		||||
)
 | 
			
		||||
from kernels.utils import install_kernel
 | 
			
		||||
from kernels.utils import (
 | 
			
		||||
    _get_privateuse_backend_name,
 | 
			
		||||
    install_kernel,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        Device(type="cuda"): LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
        )
 | 
			
		||||
        ),
 | 
			
		||||
        "npu": LayerRepository(
 | 
			
		||||
            repo_id="kernels-ext-npu/SwiGlu",
 | 
			
		||||
            layer_name="SwiGlu",
 | 
			
		||||
        ),
 | 
			
		||||
    },
 | 
			
		||||
    "SiluAndMulNoCompile": {
 | 
			
		||||
        "cuda": LayerRepository(
 | 
			
		||||
@ -122,8 +129,10 @@ def device():
 | 
			
		||||
        return "cuda"
 | 
			
		||||
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
 | 
			
		||||
        return "xpu"
 | 
			
		||||
    elif _get_privateuse_backend_name() == "npu":
 | 
			
		||||
        return "npu"
 | 
			
		||||
 | 
			
		||||
    pytest.skip("No CUDA or XPU")
 | 
			
		||||
    pytest.skip("No CUDA, NPU or XPU")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_arg_kinds():
 | 
			
		||||
@ -204,10 +213,33 @@ def test_hub_forward_xpu():
 | 
			
		||||
    assert rms_norm_with_kernel.n_calls == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.npu_only
 | 
			
		||||
def test_hub_forward_npu():
 | 
			
		||||
    torch.manual_seed(0)
 | 
			
		||||
 | 
			
		||||
    silu_and_mul = SiluAndMul()
 | 
			
		||||
    X = torch.randn((32, 64), device="npu")
 | 
			
		||||
    Y = silu_and_mul(X)
 | 
			
		||||
 | 
			
		||||
    silu_and_mul_with_kernel = kernelize(
 | 
			
		||||
        SiluAndMulWithKernel(), device="npu", mode=Mode.INFERENCE
 | 
			
		||||
    )
 | 
			
		||||
    Y_kernel = silu_and_mul_with_kernel(X)
 | 
			
		||||
 | 
			
		||||
    torch.testing.assert_close(Y_kernel, Y)
 | 
			
		||||
 | 
			
		||||
    assert silu_and_mul.n_calls == 1
 | 
			
		||||
    assert silu_and_mul_with_kernel.n_calls == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(
 | 
			
		||||
    hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
 | 
			
		||||
    reason="Skip on xpu devices",
 | 
			
		||||
)
 | 
			
		||||
@pytest.mark.skipif(
 | 
			
		||||
    _get_privateuse_backend_name() == "npu",
 | 
			
		||||
    reason="Skip on npu devices",
 | 
			
		||||
)
 | 
			
		||||
def test_rocm_kernel_mapping():
 | 
			
		||||
    """Test that ROCm shorthand device mapping works correctly."""
 | 
			
		||||
    kernel_layer_mapping = {
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user