mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-28 17:54:27 +08:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a30e82182e | |||
| a7f3b2e8ed | |||
| a6ab5d83ba | |||
| 4f9f1abfb9 | |||
| f94b7780a6 | |||
| bd28883775 | |||
| 498429e322 | |||
| 09c991af4b |
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@ -63,7 +63,7 @@ jobs:
|
||||
- name: Check README generation
|
||||
# For now, just checks that generation doesn't fail.
|
||||
run: |
|
||||
uv run kernels generate-readme kernels-community/triton-layer-norm --revision docs
|
||||
uv run kernels generate-readme kernels-community/triton-layer-norm
|
||||
|
||||
- name: Import check without torch
|
||||
run: |
|
||||
|
||||
@ -37,8 +37,14 @@ to resolve the version constraints.
|
||||
## Native Python module
|
||||
|
||||
Kernels will typically contain a native Python module with precompiled
|
||||
compute kernels and bindings. This module must fulfill the following
|
||||
requirements:
|
||||
compute kernels and bindings. This module must fulfill the requirements
|
||||
outlined in this section. For all operating systems, a kernel must not
|
||||
have dynamic library dependencies outside:
|
||||
|
||||
- Torch;
|
||||
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||
|
||||
### Linux
|
||||
|
||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||
for compatibility with Python 3.9 and later.
|
||||
@ -50,12 +56,18 @@ requirements:
|
||||
- CXXABI 1.3.11
|
||||
- GCC 7.0.0
|
||||
|
||||
These requirement can be checked with the ABI checker (see below).
|
||||
These requirement can be checked with the ABI checker (see below).
|
||||
|
||||
- No dynamic library dependencies outside:
|
||||
### macOS
|
||||
|
||||
- Torch;
|
||||
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||
for compatibility with Python 3.9 and later.
|
||||
- macOS deployment target 15.0.
|
||||
- Metal 3.0 (`-std=metal3.0`).
|
||||
|
||||
The ABI3 requirement can be checked with the ABI checker (see below).
|
||||
|
||||
### ABI checker
|
||||
|
||||
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
|
||||
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
|
||||
|
||||
7
flake.lock
generated
7
flake.lock
generated
@ -58,16 +58,15 @@
|
||||
"nixpkgs": "nixpkgs"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1749025620,
|
||||
"narHash": "sha256-V/r5KOp8FRC5n3MINDzTeS3pZz57SasFVzx12WQRQ8U=",
|
||||
"lastModified": 1750775451,
|
||||
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
|
||||
"owner": "huggingface",
|
||||
"repo": "hf-nix",
|
||||
"rev": "7ab84ffad440c530162f528a96fa062530a6c8e4",
|
||||
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"ref": "torch-cxx11",
|
||||
"repo": "hf-nix",
|
||||
"type": "github"
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
inputs = {
|
||||
hf-nix.url = "github:huggingface/hf-nix/torch-cxx11";
|
||||
hf-nix.url = "github:huggingface/hf-nix";
|
||||
nixpkgs.follows = "hf-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
@ -16,7 +16,7 @@
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
inherit (hf-nix.lib) config;
|
||||
config = hf-nix.lib.config system;
|
||||
overlays = [
|
||||
hf-nix.overlays.default
|
||||
];
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "kernels"
|
||||
version = "0.6.0"
|
||||
version = "0.6.2"
|
||||
description = "Download compute kernels"
|
||||
authors = [
|
||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||
|
||||
@ -17,6 +17,87 @@ _RE_RETURNTYPE = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def _extract_description_before_tags(docstring_mdx: str) -> str:
|
||||
"""Extract the description part of a docstring before any tags."""
|
||||
params_pos = docstring_mdx.find("<parameters>")
|
||||
returns_pos = docstring_mdx.find("<returns>")
|
||||
returntype_pos = docstring_mdx.find("<returntype>")
|
||||
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
|
||||
|
||||
if positions:
|
||||
first_tag_pos = min(positions)
|
||||
return docstring_mdx[:first_tag_pos].strip()
|
||||
else:
|
||||
return docstring_mdx.strip()
|
||||
|
||||
|
||||
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
|
||||
"""Print the parameters section from a docstring."""
|
||||
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
||||
if matches:
|
||||
header = "#" * header_level
|
||||
print(f"\n{header} Parameters")
|
||||
for match in matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
|
||||
|
||||
def _print_returns_section(
|
||||
docstring_mdx: str, *, context_name: str, header_level: int
|
||||
) -> None:
|
||||
"""Print the returns section from a docstring."""
|
||||
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
||||
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
||||
|
||||
if return_matches or returntype_matches:
|
||||
header = "#" * header_level
|
||||
print(f"\n{header} Returns")
|
||||
|
||||
if returntype_matches:
|
||||
if len(returntype_matches) > 1:
|
||||
raise ValueError(
|
||||
f"More than one <returntype> tag found in docstring for {context_name}"
|
||||
)
|
||||
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
|
||||
|
||||
if return_matches:
|
||||
for match in return_matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
|
||||
|
||||
def _get_docstring(obj, use_dict_check: bool = False) -> str:
|
||||
"""Get docstring from an object, with fallback to default message."""
|
||||
# Check whether the class/method itself has docs and not just
|
||||
# the superclass.
|
||||
if use_dict_check:
|
||||
has_doc = obj.__dict__.get("__doc__", None) is not None
|
||||
else:
|
||||
has_doc = getattr(obj, "__doc__", None) is not None
|
||||
|
||||
# We use inspect.getdoc because it does normalization.
|
||||
doc = inspect.getdoc(obj)
|
||||
|
||||
return doc if has_doc and doc is not None else "No documentation available."
|
||||
|
||||
|
||||
def _process_and_print_docstring(
|
||||
docstring: str, *, kernel_name: str, context_name: str, header_level: int
|
||||
) -> None:
|
||||
"""Convert docstring to MDX and print description, parameters, and returns sections."""
|
||||
docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
docstring, page_info={"package_name": kernel_name}
|
||||
)
|
||||
|
||||
# Print the description
|
||||
description = _extract_description_before_tags(docstring_mdx)
|
||||
print(f"\n{description}")
|
||||
|
||||
# Print parameters and returns sections
|
||||
_print_parameters_section(docstring_mdx, header_level=header_level)
|
||||
_print_returns_section(
|
||||
docstring_mdx, context_name=context_name, header_level=header_level
|
||||
)
|
||||
|
||||
|
||||
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
||||
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
|
||||
kernel_name = repo_id.split("/")[-1].replace("-", "_")
|
||||
@ -24,9 +105,10 @@ def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
||||
generate_metadata(kernel_module)
|
||||
generate_kernel_doc(kernel_module, kernel_name)
|
||||
generate_function_doc(kernel_module, kernel_name)
|
||||
generate_layers_doc(kernel_module, kernel_name)
|
||||
|
||||
|
||||
def generate_metadata(module: ModuleType):
|
||||
def generate_metadata(module: ModuleType) -> None:
|
||||
metadata = getattr(module, "__kernel_metadata__", {})
|
||||
if "tags" not in metadata:
|
||||
metadata["tags"] = ["kernel"]
|
||||
@ -39,7 +121,7 @@ def generate_metadata(module: ModuleType):
|
||||
print("---")
|
||||
|
||||
|
||||
def generate_kernel_doc(module: ModuleType, kernel_name: str):
|
||||
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
|
||||
docstring = module.__doc__.strip() if module.__doc__ is not None else None
|
||||
if docstring:
|
||||
title, rest = docstring.split("\n", 1)
|
||||
@ -49,76 +131,112 @@ def generate_kernel_doc(module: ModuleType, kernel_name: str):
|
||||
)
|
||||
|
||||
|
||||
def generate_function_doc(kernel_module, kernel_name):
|
||||
functions_info = []
|
||||
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
||||
# Do not include imported functions.
|
||||
if func.__module__ == kernel_module.__name__:
|
||||
# Exclude private functions.
|
||||
if not name.startswith("_"):
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
docstring = inspect.getdoc(func) or "No documentation available."
|
||||
functions_info.append((name, sig, docstring))
|
||||
except ValueError:
|
||||
print(
|
||||
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||
print("\n## Functions")
|
||||
|
||||
if not functions_info:
|
||||
print(
|
||||
"\nNo public top-level functions.",
|
||||
)
|
||||
return
|
||||
# Track if we found any functions
|
||||
found_functions = False
|
||||
|
||||
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
||||
# Do not include imported functions.
|
||||
if func.__module__ != kernel_module.__name__:
|
||||
continue
|
||||
|
||||
# Exclude private functions.
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
found_functions = True
|
||||
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
docstring = _get_docstring(func)
|
||||
except ValueError:
|
||||
print(
|
||||
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
for name, sig, docstring in functions_info:
|
||||
print(f"\n### Function `{name}`")
|
||||
print(f"\n`{sig}`")
|
||||
|
||||
docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
docstring, page_info={"package_name": kernel_name}
|
||||
_process_and_print_docstring(
|
||||
docstring, kernel_name=kernel_name, context_name=name, header_level=3
|
||||
)
|
||||
|
||||
params_pos = docstring_mdx.find("<parameters>")
|
||||
returns_pos = docstring_mdx.find("<returns>")
|
||||
returntype_pos = docstring_mdx.find("<returntype>")
|
||||
positions = [
|
||||
pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1
|
||||
]
|
||||
if not found_functions:
|
||||
print("\nNo public top-level functions.")
|
||||
|
||||
if positions:
|
||||
first_tag_pos = min(positions)
|
||||
# The function description is anything before the first tag.
|
||||
print(f"\n{docstring_mdx[:first_tag_pos].strip()}")
|
||||
else:
|
||||
print(f"\n{docstring_mdx.strip()}")
|
||||
|
||||
# Extract parameters
|
||||
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
||||
if matches:
|
||||
print("\n### Parameters")
|
||||
for match in matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||
# Check if layers module is available
|
||||
layers_module = getattr(kernel_module, "layers", None)
|
||||
if layers_module is None:
|
||||
return
|
||||
|
||||
# Extract return information
|
||||
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
||||
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
||||
print("\n## Layers")
|
||||
|
||||
if return_matches or returntype_matches:
|
||||
print("\n### Returns", file=sys.stdout)
|
||||
# Track if we found any classes
|
||||
found_classes = False
|
||||
|
||||
if returntype_matches:
|
||||
if len(returntype_matches) > 1:
|
||||
raise ValueError(
|
||||
f"More than one <returntype> tag found in docstring for {name} in {kernel_module.__name__}"
|
||||
)
|
||||
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
|
||||
# Exclude classes that were imported.
|
||||
if cls.__module__ != layers_module.__name__:
|
||||
continue
|
||||
|
||||
found_classes = True
|
||||
|
||||
try:
|
||||
# Get docstring, but not from superclasses.
|
||||
class_docstring = _get_docstring(cls, use_dict_check=True)
|
||||
except Exception:
|
||||
print(
|
||||
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
print(f"\n### Class `{class_name}`")
|
||||
|
||||
# Always print class description (helper handles conversion and formatting)
|
||||
class_docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
class_docstring, page_info={"package_name": kernel_name}
|
||||
)
|
||||
description = _extract_description_before_tags(class_docstring_mdx)
|
||||
print(f"\n{description}")
|
||||
|
||||
# Document methods
|
||||
print("\n#### Methods")
|
||||
|
||||
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
|
||||
# Note: also skip __init__, since extension layers cannot have a constructor.
|
||||
if method_name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Skip methods from superclasses.
|
||||
if method_name not in cls.__dict__:
|
||||
continue
|
||||
|
||||
try:
|
||||
sig = inspect.signature(method)
|
||||
method_docstring = _get_docstring(method)
|
||||
except ValueError:
|
||||
print(
|
||||
f"\n**Type**: {returntype_matches[0][0].strip()}", file=sys.stdout
|
||||
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
if return_matches:
|
||||
for match in return_matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
print(f"\n##### Method `{method_name}`")
|
||||
print(f"\n`{sig}`")
|
||||
|
||||
_process_and_print_docstring(
|
||||
method_docstring,
|
||||
kernel_name=kernel_name,
|
||||
context_name=method_name,
|
||||
header_level=6,
|
||||
)
|
||||
|
||||
if not found_classes:
|
||||
print("\nNo layers defined.")
|
||||
|
||||
@ -55,6 +55,7 @@ def build_variant() -> str:
|
||||
os = platform.system().lower()
|
||||
|
||||
if os == "darwin":
|
||||
cpu = "aarch64" if cpu == "arm64" else cpu
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
|
||||
|
||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||
|
||||
Reference in New Issue
Block a user