Mirror of https://github.com/huggingface/kernels.git (synced 2025-11-04 14:14:31 +08:00)

Compare commits: compile-no... ... release-0.... (1 commit)
Commit: 0a3c828359

.github/workflows/test.yml (vendored)
@@ -55,11 +55,10 @@ jobs:
          uv run pytest tests

      - name: Run staging tests
        env:
        env: 
          HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
        run: |
          HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
        if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0'

      - name: Check kernel conversion
        run: |
@@ -73,11 +72,6 @@ jobs:
        run: |
          uv run kernels generate-readme kernels-community/triton-layer-norm

      - name: Check kernel check
        run: |
          uv pip install kernel-abi-check
          kernels check kernels-community/activation

      - name: Import check without torch
        run: |
          uv pip uninstall torch
Makefile

@@ -1,8 +0,0 @@
.PHONY: style

export check_dirs := src examples tests

style:
	black ${check_dirs}
	isort ${check_dirs}
	ruff check ${check_dirs} --fix
@@ -2,24 +2,6 @@

## Main Functions

### kernels check

You can use `kernels check` to test compliance of a kernel on the Hub.
This currently checks that the kernel:

- Supports the currently-required Python ABI version.
- Works on supported operating system versions.

For example:

```bash
$ kernels check kernels-community/flash-attn3
Checking variant: torch28-cxx11-cu128-aarch64-linux
  🐍 Python ABI 3.9 compatible
  🐧 manylinux_2_28 compatible
[...]
```

### kernels to-wheel

We strongly recommend downloading kernels from the Hub using the `kernels`
@@ -56,3 +38,4 @@ your kernel builds to the Hub.
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
  being uploaded, it will attempt to delete the files existing under it.
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.
@@ -46,16 +46,6 @@ have dynamic library dependencies outside:
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.

## Compatibility with torch.compile

The Kernel Hub also encourages writing kernels in a `torch.compile`-compliant
way. This helps ensure that kernels remain compatible with `torch.compile`
without introducing graph breaks or triggering recompilation, which can limit
the benefits of compilation.

[Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example that checks for graph breaks and
recompilation triggers during `torch.compile`.
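The linked test is not reproduced in this diff; as a rough sketch of the kind of check it performs (the op `my_kernel_op` is a placeholder, and only standard `torch.compile` / `torch._dynamo.explain` behavior is relied on):

```python
# Sketch: verify that a kernel-backed op neither breaks the graph nor
# requires recompilation. `my_kernel_op` stands in for a Hub kernel call.
import torch


def my_kernel_op(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x)  # placeholder for a real kernel


x = torch.randn(32, 32)

# fullgraph=True makes torch.compile fail loudly on any graph break.
compiled = torch.compile(my_kernel_op, fullgraph=True)
compiled(x)

# explain() reports graph breaks without raising.
report = torch._dynamo.explain(my_kernel_op)(x)
assert report.graph_break_count == 0
```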

### Linux

- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
@@ -20,11 +20,11 @@ activation.gelu_fast(y, x)
print("Kernel successfully executed")

# Check results
expected = torch.tensor(
    [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
    device="cuda:0",
    dtype=torch.float16,
)
expected = torch.tensor([
    [0.8408, 1.9551, 2.9961],
    [4.0000, 5.0000, 6.0000],
    [7.0000, 8.0000, 9.0000]
], device='cuda:0', dtype=torch.float16)
assert torch.allclose(y, expected)

print("Calculated values are exact")
@@ -24,7 +24,6 @@
      in
      {
        formatter = pkgs.nixfmt-tree;
        packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {};
        devShells = with pkgs; rec {
          default = mkShell {
            nativeBuildInputs = [
@@ -41,7 +40,6 @@
              ++ (with python3.pkgs; [
                docutils
                huggingface-hub
                (callPackage ./nix/kernel-abi-check.nix {})
                mktestdocs
                pytest
                pytest-benchmark
@@ -1,27 +0,0 @@
{
  buildPythonPackage,
  fetchPypi,
  rustPlatform,
}:

buildPythonPackage rec {
  pname = "kernel-abi-check";
  version = "0.6.2";

  src = fetchPypi {
    inherit version;
    pname = "kernel_abi_check";
    hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys=";
  };

  cargoDeps = rustPlatform.fetchCargoVendor {
    inherit pname version src sourceRoot;
    hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A=";
  };

  sourceRoot = "kernel_abi_check-${version}/bindings/python";

  pyproject = true;

  nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ];
}
@@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.10.2.dev0"
version = "0.10.2"
description = "Download compute kernels"
authors = [
  { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -34,7 +34,6 @@ dev = [
]

[project.optional-dependencies]
abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
torch = ["torch"]
docs = [
  "hf-doc-builder",
@@ -46,9 +45,6 @@ kernels = "kernels.cli:main"
[project.entry-points."egg_info.writers"]
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"

[tool.isort]
profile = "black"
line_length = 119

[tool.ruff]
exclude = [
@@ -75,4 +71,4 @@ line-length = 119
# Ignored rules:
# "E501" -> line length violation
lint.ignore = ["E501"]
lint.select = ["E", "F", "W"]
lint.select = ["E", "F", "I", "W"]
@@ -4,6 +4,5 @@ markers =
    rocm_only: marks tests that should only run on hosts with ROCm GPUs
    darwin_only: marks tests that should only run on macOS
    xpu_only: marks tests that should only run on hosts with Intel XPUs
    npu_only: marks tests that should only run on Ascend NPUs
    token: enable tests that require a write token
    is_staging_test: Marks tests that should only run on a staging environment
@@ -1,141 +0,0 @@
from pathlib import Path
import sys

from huggingface_hub import snapshot_download
from kernels.utils import CACHE_DIR
from kernel_abi_check import (
    BinaryFormat,
    IncompatibleMacOSVersion,
    ObjectFile,
    IncompatibleAbi3Symbol,
    NonAbi3Symbol,
    IncompatibleManylinuxSymbol,
    MissingMacOSVersion,
)


def check_kernel(
    *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
    variants_path = (
        Path(
            snapshot_download(
                repo_id,
                allow_patterns=["build/*"],
                cache_dir=CACHE_DIR,
                revision=revision,
            )
        )
        / "build"
    )

    has_issues = False
    for variant_path in variants_path.iterdir():
        if not variant_path.is_dir():
            print(
                f"⛔ `build/` must only contain directories, found: {variant_path.name}",
                file=sys.stderr,
            )
            has_issues = True
            continue

        print(f"Checking variant: {variant_path.name}", file=sys.stderr)

        indent = 2

        for dylib_path in variant_path.rglob("*.so"):
            print_with_indent(
                indent,
                f"Dynamic library {dylib_path.relative_to(variant_path)}:",
            )

            o = ObjectFile(dylib_path)
            has_issues |= check_abi3(o, python_abi, indent + 2)

            # TODO: also check operating system
            if o.format() == BinaryFormat.ELF:
                has_issues |= check_manylinux(o, manylinux, indent + 2)
            elif o.format() == BinaryFormat.MACH_O:
                has_issues |= check_macos(o, macos, indent + 2)

    if has_issues:
        sys.exit(1)


def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool:
    has_issues = False
    violations = object_file.check_python_abi(python_abi)
    if violations != []:
        has_issues = True
        print_with_indent(
            indent,
            f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:",
        )
        for violation in violations:
            if isinstance(violation, IncompatibleAbi3Symbol):
                print_with_indent(
                    indent + 3,
                    f"{violation.name}: {violation.version_added}",
                )
            elif isinstance(violation, NonAbi3Symbol):
                print_with_indent(
                    indent + 3,
                    f"{violation.name}",
                )
    else:
        print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible")

    return has_issues


def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool:
    has_issues = False
    violations = object_file.check_macos(macos)
    if violations != []:
        has_issues = True
        print_with_indent(
            indent,
            f"⛔ Found incompatibility with macOS {macos}:",
        )

        for violation in violations:
            if isinstance(violation, MissingMacOSVersion):
                print_with_indent(
                    indent + 3,
                    "shared library does not contain macOS version",
                )
            elif isinstance(violation, IncompatibleMacOSVersion):
                print_with_indent(
                    indent + 3,
                    f"shared library requires macOS {violation.version}",
                )
    else:
        print_with_indent(indent, f"🍏 compatible with macOS {macos}")

    return has_issues


def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool:
    has_issues = False
    violations = object_file.check_manylinux(manylinux)
    if violations != []:
        has_issues = True
        print_with_indent(
            indent,
            f"⛔ Found symbols that are incompatible with {manylinux}:",
        )

        for violation in violations:
            if isinstance(violation, IncompatibleManylinuxSymbol):
                print_with_indent(
                    indent + 3,
                    f"{violation.name}_{violation.dep}: {violation.version}",
                )
    else:
        print_with_indent(indent, f"🐧 {manylinux} compatible")

    return has_issues


def print_with_indent(indent: int, message: str):
    print(f"{' ' * indent}{message}", file=sys.stderr)
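For reference, a rough sketch of driving the module above directly from Python instead of through the CLI; the argument values simply mirror the CLI defaults and the repository used elsewhere in this diff:

```python
# Sketch: call the checker directly. Normally this runs via `kernels check`.
from kernels.check import check_kernel

check_kernel(
    macos="15.0",                    # CLI default shown below
    manylinux="manylinux_2_28",      # CLI default shown below
    python_abi="3.9",                # CLI default shown below
    repo_id="kernels-community/activation",
    revision="main",
)
```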
@@ -20,31 +20,6 @@ def main():
    )
    subparsers = parser.add_subparsers(required=True)

    check_parser = subparsers.add_parser("check", help="Check a kernel for compliance")
    check_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
    check_parser.add_argument(
        "--revision",
        type=str,
        default="main",
        help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
    )
    check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0")
    check_parser.add_argument(
        "--manylinux", type=str, help="Manylinux version", default="manylinux_2_28"
    )
    check_parser.add_argument(
        "--python-abi", type=str, help="Python ABI version", default="3.9"
    )
    check_parser.set_defaults(
        func=lambda args: check_kernel(
            macos=args.macos,
            manylinux=args.manylinux,
            python_abi=args.python_abi,
            repo_id=args.repo_id,
            revision=args.revision,
        )
    )

    download_parser = subparsers.add_parser("download", help="Download locked kernels")
    download_parser.add_argument(
        "project_dir",
@@ -230,24 +205,3 @@ class _JSONEncoder(json.JSONEncoder):
        if dataclasses.is_dataclass(o):
            return dataclasses.asdict(o)
        return super().default(o)


def check_kernel(
    *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
    try:
        import kernels.check
    except ImportError:
        print(
            "`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check",
            file=sys.stderr,
        )
        sys.exit(1)

    kernels.check.check_kernel(
        macos=macos,
        manylinux=manylinux,
        python_abi=python_abi,
        repo_id=repo_id,
        revision=revision,
    )
@@ -87,7 +87,7 @@ class Device:

    Args:
        type (`str`):
            The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu").
            The device type (e.g., "cuda", "mps", "rocm", "xpu").
        properties ([`CUDAProperties`], *optional*):
            Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.

@@ -109,9 +109,6 @@ class Device:

        # XPU device (e.g., Intel(R) Data Center GPU Max 1550)
        xpu_device = Device(type="xpu")

        # NPU device (Huawei Ascend)
        npu_device = Device(type="npu")
        ```
    """

@@ -133,8 +130,6 @@ class Device:
            return _MPSRepos()
        elif self.type == "xpu":
            return _XPURepos()
        elif self.type == "npu":
            return _NPURepos()
        else:
            raise ValueError(f"Unknown device type: {self.type}")

@@ -477,26 +472,6 @@ class _XPURepos(_DeviceRepos):
        self._repos = repos


class _NPURepos(_DeviceRepos):
    _repos: Dict[Mode, LayerRepositoryProtocol]

    def __init__(self):
        super().__init__()
        self._repos = {}

    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        return self._repos

    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        if device.type != "npu":
            raise ValueError(f"Device type must be 'npu', got {device.type}")

        self._repos = repos


class _MPSRepos(_DeviceRepos):
    _repos: Dict[Mode, LayerRepositoryProtocol]

@@ -581,7 +556,7 @@ class _ROCMRepos(_DeviceRepos):

def _validate_device_type(device_type: str) -> None:
    """Validate that the device type is supported."""
    supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"}
    supported_devices = {"cuda", "rocm", "mps", "xpu"}
    if device_type not in supported_devices:
        raise ValueError(
            f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -839,7 +814,7 @@ def kernelize(
            `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
            `torch.compile`.
        device (`Union[str, torch.device]`, *optional*):
            The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu".
            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
            The device type will be inferred from the model parameters when not provided.
        use_fallback (`bool`, *optional*, defaults to `True`):
            Whether to use the original forward method of modules when no compatible kernel could be found.
@@ -863,7 +838,7 @@ def kernelize(
                return F.silu(x[..., :d]) * x[..., d:]

        mapping = {
            "SiluAndMul": {
            "LayerNorm": {
                "cuda": LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="SiluAndMul",
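To tie the NPU plumbing above to usage, here is a small sketch based on the layer-mapping test further down in this diff. It assumes `use_kernel_forward_from_hub` and `register_kernel_mapping` are available as public helpers in `kernels` (adjust if the actual API differs); the `kernels-ext-npu/SwiGlu` repository is taken from that test mapping, not verified independently.

```python
# Sketch: register a mapping with an NPU entry and kernelize a module.
import torch
import torch.nn.functional as F
from torch import nn

from kernels import (
    Device,
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


register_kernel_mapping(
    {
        "SiluAndMul": {
            Device(type="cuda"): LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            # Shorthand string key, as used in the tests below.
            "npu": LayerRepository(
                repo_id="kernels-ext-npu/SwiGlu",
                layer_name="SwiGlu",
            ),
        }
    }
)

model = kernelize(SiluAndMul(), device="npu", mode=Mode.INFERENCE)
```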
@@ -35,14 +35,6 @@ def _get_cache_dir() -> Optional[str]:
CACHE_DIR: Optional[str] = _get_cache_dir()


def _get_privateuse_backend_name() -> Optional[str]:
    import torch

    if hasattr(torch._C, "_get_privateuse1_backend_name"):
        return torch._C._get_privateuse1_backend_name()
    return None


def build_variant() -> str:
    import torch

@@ -57,14 +49,9 @@ def build_variant() -> str:
    elif torch.version.xpu is not None:
        version = torch.version.xpu
        compute_framework = f"xpu{version[0:4]}{version[5:6]}"
    elif _get_privateuse_backend_name() == "npu":
        from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found]

        cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
        compute_framework = f"cann{cann_major}{cann_minor}"
    else:
        raise AssertionError(
            "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled."
            "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
        )

    torch_version = parse(torch.__version__)
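As a rough illustration of what the branch above produces: the CUDA-style string below comes from the docs earlier in this diff, while the CANN comment is an extrapolation from the `f"cann{cann_major}{cann_minor}"` assignment, so treat it as a guess.

```python
# Sketch: build_variant() assembles a variant name from the compute framework,
# Torch version, C++ ABI, CPU architecture, and operating system.
from kernels.utils import build_variant

print(build_variant())
# CUDA build of Torch:    e.g. "torch28-cxx11-cu128-aarch64-linux"
# Ascend NPU (torch_npu): a "cann"-prefixed compute-framework segment, with
#                         digits taken from the installed CANN version
#                         (extrapolated, unverified)
```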
@@ -3,8 +3,6 @@ import sys
import pytest
import torch

from kernels.utils import _get_privateuse_backend_name

has_cuda = (
    hasattr(torch.version, "cuda")
    and torch.version.cuda is not None
@@ -20,7 +18,6 @@ has_xpu = (
    and torch.version.xpu is not None
    and torch.xpu.device_count() > 0
)
has_npu = _get_privateuse_backend_name() == "npu"


def pytest_addoption(parser):
@@ -40,7 +37,5 @@ def pytest_runtest_setup(item):
        pytest.skip("skipping macOS-only test on non-macOS platform")
    if "xpu_only" in item.keywords and not has_xpu:
        pytest.skip("skipping XPU-only test on host without XPU")
    if "npu_only" in item.keywords and not has_npu:
        pytest.skip("skipping NPU-only test on host without NPU")
    if "token" in item.keywords and not item.config.getoption("--token"):
        pytest.skip("need --token option to run this test")
@@ -35,7 +35,6 @@ def test_load_locked():
    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")


@pytest.mark.cuda_only
def test_layer_locked():
    project_dir = Path(__file__).parent / "layer_locking"
@@ -21,21 +21,14 @@ from kernels.layer import (
    _KERNEL_MAPPING,
    _validate_layer,
)
from kernels.utils import (
    _get_privateuse_backend_name,
    install_kernel,
)
from kernels.utils import install_kernel

kernel_layer_mapping = {
    "SiluAndMul": {
        Device(type="cuda"): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
        "npu": LayerRepository(
            repo_id="kernels-ext-npu/SwiGlu",
            layer_name="SwiGlu",
        ),
        )
    },
    "SiluAndMulNoCompile": {
        "cuda": LayerRepository(
@@ -129,10 +122,8 @@ def device():
        return "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    elif _get_privateuse_backend_name() == "npu":
        return "npu"

    pytest.skip("No CUDA, NPU or XPU")
    pytest.skip("No CUDA or XPU")


def test_arg_kinds():
@@ -213,33 +204,10 @@ def test_hub_forward_xpu():
    assert rms_norm_with_kernel.n_calls == 0


@pytest.mark.npu_only
def test_hub_forward_npu():
    torch.manual_seed(0)

    silu_and_mul = SiluAndMul()
    X = torch.randn((32, 64), device="npu")
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = kernelize(
        SiluAndMulWithKernel(), device="npu", mode=Mode.INFERENCE
    )
    Y_kernel = silu_and_mul_with_kernel(X)

    torch.testing.assert_close(Y_kernel, Y)

    assert silu_and_mul.n_calls == 1
    assert silu_and_mul_with_kernel.n_calls == 0


@pytest.mark.skipif(
    hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
    reason="Skip on xpu devices",
)
@pytest.mark.skipif(
    _get_privateuse_backend_name() == "npu",
    reason="Skip on npu devices",
)
def test_rocm_kernel_mapping():
    """Test that ROCm shorthand device mapping works correctly."""
    kernel_layer_mapping = {