mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-21 21:38:52 +08:00
Compare commits
1 Commits
v0.10.0
...
add-compat
Author | SHA1 | Date | |
---|---|---|---|
cc66a53025 |
@ -4,6 +4,8 @@ import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from kernels.compat import tomllib
|
||||
from kernels.lockfile import KernelLock, get_kernel_locks
|
||||
from kernels.utils import install_kernel, install_kernel_all_variants
|
||||
@ -36,6 +38,23 @@ def main():
|
||||
)
|
||||
lock_parser.set_defaults(func=lock_kernels)
|
||||
|
||||
# Add a new compatibility command
|
||||
compat_parser = subparsers.add_parser(
|
||||
"compatibility", help="Show kernel build compatibility"
|
||||
)
|
||||
compat_parser.add_argument(
|
||||
"repo_id",
|
||||
type=str,
|
||||
help="The repository ID of the kernel (e.g., 'kernels-community/activation')",
|
||||
)
|
||||
compat_parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default="main",
|
||||
help="The revision of the kernel (default: main)",
|
||||
)
|
||||
compat_parser.set_defaults(func=check_compatibility)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
@ -91,6 +110,53 @@ def lock_kernels(args):
|
||||
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
||||
|
||||
|
||||
def check_compatibility(args):
|
||||
"""Check build compatibility for a kernel by reading its build.toml file."""
|
||||
try:
|
||||
# Download only the build.toml file from the repository
|
||||
build_toml_path = hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename="build.toml",
|
||||
revision=args.revision,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
f"Error: Could not find build.toml in repository {args.repo_id}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Parse the build.toml file
|
||||
try:
|
||||
with open(build_toml_path, "rb") as f:
|
||||
content = f.read().decode("utf-8")
|
||||
|
||||
# Simple check for compatibility without full parsing
|
||||
is_universal = "language" in content and "python" in content
|
||||
has_cuda = "cuda-capabilities" in content
|
||||
has_rocm = "rocm-archs" in content
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading build.toml: {str(e)}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Print the compatibility
|
||||
print(f"Kernel: {args.repo_id}")
|
||||
print("Compatibility: ", end="")
|
||||
|
||||
if is_universal:
|
||||
print("universal")
|
||||
else:
|
||||
compatibilities = []
|
||||
if has_cuda:
|
||||
compatibilities.append("cuda")
|
||||
if has_rocm:
|
||||
compatibilities.append("rocm")
|
||||
print(", ".join(compatibilities) if compatibilities else "unknown")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
class _JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if dataclasses.is_dataclass(o):
|
||||
|
Reference in New Issue
Block a user