Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-22 05:48:52 +08:00)

Compare commits: compile-no ... v0.10.2 (1 commit)

Commit 0a3c828359
.github/workflows/test.yml (vendored, 8 changes)
@@ -55,11 +55,10 @@ jobs:
          uv run pytest tests

      - name: Run staging tests
        env:
        env:
          HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
        run: |
          HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
        if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0'

      - name: Check kernel conversion
        run: |
@@ -73,11 +72,6 @@ jobs:
        run: |
          uv run kernels generate-readme kernels-community/triton-layer-norm

      - name: Check kernel check
        run: |
          uv pip install kernel-abi-check
          kernels check kernels-community/activation

      - name: Import check without torch
        run: |
          uv pip uninstall torch
Makefile (8 changes)
@@ -1,8 +0,0 @@
.PHONY: style

export check_dirs := src examples tests

style:
	black ${check_dirs}
	isort ${check_dirs}
	ruff check ${check_dirs} --fix
@@ -2,24 +2,6 @@

## Main Functions

### kernels check

You can use `kernels check` to test compliance of a kernel on the Hub.
This currently checks that the kernel:

- Supports the currently-required Python ABI version.
- Works on supported operating system versions.

For example:

```bash
$ kernels check kernels-community/flash-attn3
Checking variant: torch28-cxx11-cu128-aarch64-linux
🐍 Python ABI 3.9 compatible
🐧 manylinux_2_28 compatible
[...]
```
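The revision and target versions can also be set explicitly. The options and defaults below mirror the `check` subparser added to `cli.py` later in this diff:

```bash
# Defaults made explicit: Python ABI 3.9, manylinux_2_28, macOS 15.0, revision "main".
kernels check kernels-community/flash-attn3 \
    --revision main \
    --python-abi 3.9 \
    --manylinux manylinux_2_28 \
    --macos 15.0
```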
### kernels to-wheel

We strongly recommend downloading kernels from the Hub using the `kernels`
@@ -56,3 +38,4 @@ your kernel builds to the Hub.
- If a repo with the `repo_id` already exists and its `build` directory already contains the build
  variant being uploaded, the upload will attempt to delete the existing files under it.
- Make sure you are authenticated (run `hf auth login` if needed) so the upload to the Hub can succeed; a rough invocation is sketched below.
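As a rough sketch of the upload flow described above (the subcommand name and flag below are assumptions; only `repo_id` and `hf auth login` appear in this diff):

```bash
# Authenticate first so the upload can write to the Hub.
hf auth login

# Hypothetical invocation: push locally built variants to a kernel repo.
kernels upload ./my-kernel-build-dir --repo_id="your-username/my-kernel"
```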
@@ -46,16 +46,6 @@ have dynamic library dependencies outside:

- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.

## Compatibility with torch.compile

The Kernel Hub also encourages writing kernels in a `torch.compile`-compliant
way. This helps ensure that kernels work with `torch.compile` without
introducing graph breaks or triggering recompilation, either of which can limit
the benefits of compilation.

[Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example that checks for graph breaks and
recompilation triggers during `torch.compile`.
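The linked test is the reference; the following is a minimal standalone sketch (not taken from that test) of how graph breaks and recompilations can be inspected with `torch._dynamo`, using a plain silu-and-mul function as a stand-in for a Hub kernel:

```python
import torch
import torch._dynamo as dynamo
from torch._dynamo.utils import counters


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a kernel loaded from the Hub.
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


# Graph breaks: `explain` traces the function and reports how many breaks occurred.
explanation = dynamo.explain(silu_and_mul)(torch.randn(32, 64))
print("graph breaks:", explanation.graph_break_count)

# Recompilation: compile once, call with several shapes, then inspect how many
# distinct graphs Dynamo captured (ideally one).
dynamo.reset()
compiled = torch.compile(silu_and_mul, dynamic=True)
for rows in (16, 32, 64):
    compiled(torch.randn(rows, 64))
print("unique graphs:", counters["stats"]["unique_graphs"])
```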
### Linux

- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
@@ -20,11 +20,11 @@ activation.gelu_fast(y, x)
print("Kernel successfully executed")

# Check results
expected = torch.tensor(
    [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
    device="cuda:0",
    dtype=torch.float16,
)
expected = torch.tensor([
    [0.8408, 1.9551, 2.9961],
    [4.0000, 5.0000, 6.0000],
    [7.0000, 8.0000, 9.0000]
], device='cuda:0', dtype=torch.float16)
assert torch.allclose(y, expected)

print("Calculated values are exact")
@@ -24,7 +24,6 @@
    in
    {
      formatter = pkgs.nixfmt-tree;
      packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {};
      devShells = with pkgs; rec {
        default = mkShell {
          nativeBuildInputs = [
@@ -41,7 +40,6 @@
            ++ (with python3.pkgs; [
              docutils
              huggingface-hub
              (callPackage ./nix/kernel-abi-check.nix {})
              mktestdocs
              pytest
              pytest-benchmark
@@ -1,27 +0,0 @@
{
  buildPythonPackage,
  fetchPypi,
  rustPlatform,
}:

buildPythonPackage rec {
  pname = "kernel-abi-check";
  version = "0.6.2";

  src = fetchPypi {
    inherit version;
    pname = "kernel_abi_check";
    hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys=";
  };

  cargoDeps = rustPlatform.fetchCargoVendor {
    inherit pname version src sourceRoot;
    hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A=";
  };

  sourceRoot = "kernel_abi_check-${version}/bindings/python";

  pyproject = true;

  nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ];
}
@@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.10.2.dev0"
version = "0.10.2"
description = "Download compute kernels"
authors = [
    { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -34,7 +34,6 @@ dev = [
]

[project.optional-dependencies]
abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"]
torch = ["torch"]
docs = [
    "hf-doc-builder",
@@ -46,9 +45,6 @@ kernels = "kernels.cli:main"
[project.entry-points."egg_info.writers"]
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"

[tool.isort]
profile = "black"
line_length = 119

[tool.ruff]
exclude = [
@@ -75,4 +71,4 @@ line-length = 119
# Ignored rules:
# "E501" -> line length violation
lint.ignore = ["E501"]
lint.select = ["E", "F", "W"]
lint.select = ["E", "F", "I", "W"]
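Assuming the package is published with this metadata, the new `abi-check` extra can be installed alongside `kernels` like this:

```bash
# Pulls in kernel-abi-check (>=0.6.2,<0.7.0), which `kernels check` needs at runtime.
pip install "kernels[abi-check]"
```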
@@ -4,6 +4,5 @@ markers =
    rocm_only: marks tests that should only run on hosts with ROCm GPUs
    darwin_only: marks tests that should only run on macOS
    xpu_only: marks tests that should only run on hosts with Intel XPUs
    npu_only: marks tests that should only run on Ascend NPUs
    token: enable tests that require a write token
    is_staging_test: Marks tests that should only run on a staging environment
@@ -1,141 +0,0 @@
from pathlib import Path
import sys

from huggingface_hub import snapshot_download
from kernels.utils import CACHE_DIR
from kernel_abi_check import (
    BinaryFormat,
    IncompatibleMacOSVersion,
    ObjectFile,
    IncompatibleAbi3Symbol,
    NonAbi3Symbol,
    IncompatibleManylinuxSymbol,
    MissingMacOSVersion,
)


def check_kernel(
    *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
    variants_path = (
        Path(
            snapshot_download(
                repo_id,
                allow_patterns=["build/*"],
                cache_dir=CACHE_DIR,
                revision=revision,
            )
        )
        / "build"
    )

    has_issues = False
    for variant_path in variants_path.iterdir():
        if not variant_path.is_dir():
            print(
                f"⛔ `build/` must only contain directories, found: {variant_path.name}",
                file=sys.stderr,
            )
            has_issues = True
            continue

        print(f"Checking variant: {variant_path.name}", file=sys.stderr)

        indent = 2

        for dylib_path in variant_path.rglob("*.so"):
            print_with_indent(
                indent,
                f"Dynamic library {dylib_path.relative_to(variant_path)}:",
            )

            o = ObjectFile(dylib_path)
            has_issues |= check_abi3(o, python_abi, indent + 2)

            # TODO: also check operating system
            if o.format() == BinaryFormat.ELF:
                has_issues |= check_manylinux(o, manylinux, indent + 2)
            elif o.format() == BinaryFormat.MACH_O:
                has_issues |= check_macos(o, macos, indent + 2)

    if has_issues:
        sys.exit(1)


def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool:
    has_issues = False
    violations = object_file.check_python_abi(python_abi)
    if violations != []:
        has_issues = True
        print_with_indent(
            indent,
            f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:",
        )
        for violation in violations:
            if isinstance(violation, IncompatibleAbi3Symbol):
                print_with_indent(
                    indent + 3,
                    f"{violation.name}: {violation.version_added}",
                )
            elif isinstance(violation, NonAbi3Symbol):
                print_with_indent(
                    indent + 3,
                    f"{violation.name}",
                )
    else:
        print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible")

    return has_issues


def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool:
    has_issues = False
    violations = object_file.check_macos(macos)
    if violations != []:
        has_issues = True
        print_with_indent(
            indent,
            f"⛔ Found incompatibility with macOS {macos}:",
        )

        for violation in violations:
            if isinstance(violation, MissingMacOSVersion):
                print_with_indent(
                    indent + 3,
                    "shared library does not contain macOS version",
                )
            elif isinstance(violation, IncompatibleMacOSVersion):
                print_with_indent(
                    indent + 3,
                    f"shared library requires macOS {violation.version}",
                )
    else:
        print_with_indent(indent, f"🍏 compatible with macOS {macos}")

    return has_issues


def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool:
    has_issues = False
    violations = object_file.check_manylinux(manylinux)
    if violations != []:
        has_issues = True
        print_with_indent(
            indent,
            f"⛔ Found symbols that are incompatible with {manylinux}:",
        )

        for violation in violations:
            if isinstance(violation, IncompatibleManylinuxSymbol):
                print_with_indent(
                    indent + 3,
                    f"{violation.name}_{violation.dep}: {violation.version}",
                )
    else:
        print_with_indent(indent, f"🐧 {manylinux} compatible")

    return has_issues


def print_with_indent(indent: int, message: str):
    print(f"{' ' * indent}{message}", file=sys.stderr)
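For reference, the same check can be driven from Python. This is a minimal sketch that mirrors how `cli.py` (the next file in this diff) calls into this module; the repo ID is just an example:

```python
from kernels.check import check_kernel

# Prints per-variant results to stderr and exits with status 1 on any
# ABI, manylinux, or macOS incompatibility.
check_kernel(
    macos="15.0",
    manylinux="manylinux_2_28",
    python_abi="3.9",
    repo_id="kernels-community/activation",
    revision="main",
)
```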
@@ -20,31 +20,6 @@ def main():
    )
    subparsers = parser.add_subparsers(required=True)

    check_parser = subparsers.add_parser("check", help="Check a kernel for compliance")
    check_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
    check_parser.add_argument(
        "--revision",
        type=str,
        default="main",
        help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
    )
    check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0")
    check_parser.add_argument(
        "--manylinux", type=str, help="Manylinux version", default="manylinux_2_28"
    )
    check_parser.add_argument(
        "--python-abi", type=str, help="Python ABI version", default="3.9"
    )
    check_parser.set_defaults(
        func=lambda args: check_kernel(
            macos=args.macos,
            manylinux=args.manylinux,
            python_abi=args.python_abi,
            repo_id=args.repo_id,
            revision=args.revision,
        )
    )

    download_parser = subparsers.add_parser("download", help="Download locked kernels")
    download_parser.add_argument(
        "project_dir",
@@ -230,24 +205,3 @@ class _JSONEncoder(json.JSONEncoder):
        if dataclasses.is_dataclass(o):
            return dataclasses.asdict(o)
        return super().default(o)


def check_kernel(
    *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
    try:
        import kernels.check
    except ImportError:
        print(
            "`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check",
            file=sys.stderr,
        )
        sys.exit(1)

    kernels.check.check_kernel(
        macos=macos,
        manylinux=manylinux,
        python_abi=python_abi,
        repo_id=repo_id,
        revision=revision,
    )
@@ -87,7 +87,7 @@ class Device:

    Args:
        type (`str`):
            The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu").
            The device type (e.g., "cuda", "mps", "rocm", "xpu").
        properties ([`CUDAProperties`], *optional*):
            Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.

@@ -109,9 +109,6 @@ class Device:

        # XPU device (e.g., Intel(R) Data Center GPU Max 1550)
        xpu_device = Device(type="xpu")

        # NPU device (Huawei Ascend)
        npu_device = Device(type="npu")
        ```
    """

@@ -133,8 +130,6 @@ class Device:
            return _MPSRepos()
        elif self.type == "xpu":
            return _XPURepos()
        elif self.type == "npu":
            return _NPURepos()
        else:
            raise ValueError(f"Unknown device type: {self.type}")

@@ -477,26 +472,6 @@ class _XPURepos(_DeviceRepos):
        self._repos = repos


class _NPURepos(_DeviceRepos):
    _repos: Dict[Mode, LayerRepositoryProtocol]

    def __init__(self):
        super().__init__()
        self._repos = {}

    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        return self._repos

    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        if device.type != "npu":
            raise ValueError(f"Device type must be 'npu', got {device.type}")

        self._repos = repos


class _MPSRepos(_DeviceRepos):
    _repos: Dict[Mode, LayerRepositoryProtocol]

@@ -581,7 +556,7 @@ class _ROCMRepos(_DeviceRepos):

def _validate_device_type(device_type: str) -> None:
    """Validate that the device type is supported."""
    supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"}
    supported_devices = {"cuda", "rocm", "mps", "xpu"}
    if device_type not in supported_devices:
        raise ValueError(
            f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -839,7 +814,7 @@ def kernelize(
            `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
            `torch.compile`.
        device (`Union[str, torch.device]`, *optional*):
            The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu".
            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
            The device type will be inferred from the model parameters when not provided.
        use_fallback (`bool`, *optional*, defaults to `True`):
            Whether to use the original forward method of modules when no compatible kernel could be found.
@@ -863,7 +838,7 @@ def kernelize(
        return F.silu(x[..., :d]) * x[..., d:]

    mapping = {
        "SiluAndMul": {
        "LayerNorm": {
            "cuda": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
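To make the `kernelize` docstring above concrete, here is a minimal sketch; the model and its layers are placeholders, and only `kernelize`, `Mode`, and the `device` argument come from this diff:

```python
import torch.nn as nn

from kernels import Mode, kernelize

# Placeholder model; in practice its layers would be registered for Hub kernels,
# and unregistered modules keep their original forward methods.
model = nn.Sequential(nn.Linear(64, 64), nn.SiLU())

# Select kernels suited to training under torch.compile. Passing the device
# explicitly skips inference from the model parameters (assumes a CUDA host).
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")
```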
@@ -35,14 +35,6 @@ def _get_cache_dir() -> Optional[str]:
CACHE_DIR: Optional[str] = _get_cache_dir()


def _get_privateuse_backend_name() -> Optional[str]:
    import torch

    if hasattr(torch._C, "_get_privateuse1_backend_name"):
        return torch._C._get_privateuse1_backend_name()
    return None


def build_variant() -> str:
    import torch

@@ -57,14 +49,9 @@ def build_variant() -> str:
    elif torch.version.xpu is not None:
        version = torch.version.xpu
        compute_framework = f"xpu{version[0:4]}{version[5:6]}"
    elif _get_privateuse_backend_name() == "npu":
        from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found]

        cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
        compute_framework = f"cann{cann_major}{cann_minor}"
    else:
        raise AssertionError(
            "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled."
            "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
        )

    torch_version = parse(torch.__version__)
@@ -3,8 +3,6 @@ import sys
import pytest
import torch

from kernels.utils import _get_privateuse_backend_name

has_cuda = (
    hasattr(torch.version, "cuda")
    and torch.version.cuda is not None
@@ -20,7 +18,6 @@ has_xpu = (
    and torch.version.xpu is not None
    and torch.xpu.device_count() > 0
)
has_npu = _get_privateuse_backend_name() == "npu"


def pytest_addoption(parser):
@@ -40,7 +37,5 @@ def pytest_runtest_setup(item):
        pytest.skip("skipping macOS-only test on non-macOS platform")
    if "xpu_only" in item.keywords and not has_xpu:
        pytest.skip("skipping XPU-only test on host without XPU")
    if "npu_only" in item.keywords and not has_npu:
        pytest.skip("skipping NPU-only test on host without NPU")
    if "token" in item.keywords and not item.config.getoption("--token"):
        pytest.skip("need --token option to run this test")
@@ -35,7 +35,6 @@ def test_load_locked():
    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")


@pytest.mark.cuda_only
def test_layer_locked():
    project_dir = Path(__file__).parent / "layer_locking"
@@ -21,21 +21,14 @@ from kernels.layer import (
    _KERNEL_MAPPING,
    _validate_layer,
)
from kernels.utils import (
    _get_privateuse_backend_name,
    install_kernel,
)
from kernels.utils import install_kernel

kernel_layer_mapping = {
    "SiluAndMul": {
        Device(type="cuda"): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
        "npu": LayerRepository(
            repo_id="kernels-ext-npu/SwiGlu",
            layer_name="SwiGlu",
        ),
        )
    },
    "SiluAndMulNoCompile": {
        "cuda": LayerRepository(
@@ -129,10 +122,8 @@ def device():
        return "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    elif _get_privateuse_backend_name() == "npu":
        return "npu"

    pytest.skip("No CUDA, NPU or XPU")
    pytest.skip("No CUDA or XPU")


def test_arg_kinds():
@@ -213,33 +204,10 @@ def test_hub_forward_xpu():
    assert rms_norm_with_kernel.n_calls == 0


@pytest.mark.npu_only
def test_hub_forward_npu():
    torch.manual_seed(0)

    silu_and_mul = SiluAndMul()
    X = torch.randn((32, 64), device="npu")
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = kernelize(
        SiluAndMulWithKernel(), device="npu", mode=Mode.INFERENCE
    )
    Y_kernel = silu_and_mul_with_kernel(X)

    torch.testing.assert_close(Y_kernel, Y)

    assert silu_and_mul.n_calls == 1
    assert silu_and_mul_with_kernel.n_calls == 0


@pytest.mark.skipif(
    hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
    reason="Skip on xpu devices",
)
@pytest.mark.skipif(
    _get_privateuse_backend_name() == "npu",
    reason="Skip on npu devices",
)
def test_rocm_kernel_mapping():
    """Test that ROCm shorthand device mapping works correctly."""
    kernel_layer_mapping = {