Compare commits

...

1 Commits

Author SHA1 Message Date
cc66a53025 feat: add compatibility command for rocm, cuda, universal 2025-04-07 14:07:15 -04:00

View File

@ -4,6 +4,8 @@ import json
import sys
from pathlib import Path
from huggingface_hub import hf_hub_download
from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants
@ -36,6 +38,23 @@ def main():
)
lock_parser.set_defaults(func=lock_kernels)
# Add a new compatibility command
compat_parser = subparsers.add_parser(
"compatibility", help="Show kernel build compatibility"
)
compat_parser.add_argument(
"repo_id",
type=str,
help="The repository ID of the kernel (e.g., 'kernels-community/activation')",
)
compat_parser.add_argument(
"--revision",
type=str,
default="main",
help="The revision of the kernel (default: main)",
)
compat_parser.set_defaults(func=check_compatibility)
args = parser.parse_args()
args.func(args)
@ -91,6 +110,53 @@ def lock_kernels(args):
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
def check_compatibility(args):
"""Check build compatibility for a kernel by reading its build.toml file."""
try:
# Download only the build.toml file from the repository
build_toml_path = hf_hub_download(
repo_id=args.repo_id,
filename="build.toml",
revision=args.revision,
)
except Exception:
print(
f"Error: Could not find build.toml in repository {args.repo_id}.",
file=sys.stderr,
)
sys.exit(1)
# Parse the build.toml file
try:
with open(build_toml_path, "rb") as f:
content = f.read().decode("utf-8")
# Simple check for compatibility without full parsing
is_universal = "language" in content and "python" in content
has_cuda = "cuda-capabilities" in content
has_rocm = "rocm-archs" in content
except Exception as e:
print(f"Error reading build.toml: {str(e)}", file=sys.stderr)
sys.exit(1)
# Print the compatibility
print(f"Kernel: {args.repo_id}")
print("Compatibility: ", end="")
if is_universal:
print("universal")
else:
compatibilities = []
if has_cuda:
compatibilities.append("cuda")
if has_rocm:
compatibilities.append("rocm")
print(", ".join(compatibilities) if compatibilities else "unknown")
return 0
class _JSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):