Compare commits

..

8 Commits

7 changed files with 203 additions and 73 deletions

View File

@ -63,7 +63,7 @@ jobs:
- name: Check README generation
# For now, just checks that generation doesn't fail.
run: |
uv run kernels generate-readme kernels-community/triton-layer-norm --revision docs
uv run kernels generate-readme kernels-community/triton-layer-norm
- name: Import check without torch
run: |

View File

@ -37,8 +37,14 @@ to resolve the version constraints.
## Native Python module
Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the following
requirements:
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
### Linux
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
@ -50,12 +56,18 @@ requirements:
- CXXABI 1.3.11
- GCC 7.0.0
These requirement can be checked with the ABI checker (see below).
These requirement can be checked with the ABI checker (see below).
- No dynamic library dependencies outside:
### macOS
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).
The ABI3 requirement can be checked with the ABI checker (see below).
### ABI checker
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):

7
flake.lock generated
View File

@ -58,16 +58,15 @@
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1749025620,
"narHash": "sha256-V/r5KOp8FRC5n3MINDzTeS3pZz57SasFVzx12WQRQ8U=",
"lastModified": 1750775451,
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "7ab84ffad440c530162f528a96fa062530a6c8e4",
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "torch-cxx11",
"repo": "hf-nix",
"type": "github"
}

View File

@ -1,6 +1,6 @@
{
inputs = {
hf-nix.url = "github:huggingface/hf-nix/torch-cxx11";
hf-nix.url = "github:huggingface/hf-nix";
nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
@ -16,7 +16,7 @@
let
pkgs = import nixpkgs {
inherit system;
inherit (hf-nix.lib) config;
config = hf-nix.lib.config system;
overlays = [
hf-nix.overlays.default
];

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.6.0"
version = "0.6.2"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },

View File

@ -17,6 +17,87 @@ _RE_RETURNTYPE = re.compile(
)
def _extract_description_before_tags(docstring_mdx: str) -> str:
"""Extract the description part of a docstring before any tags."""
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
if positions:
first_tag_pos = min(positions)
return docstring_mdx[:first_tag_pos].strip()
else:
return docstring_mdx.strip()
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
"""Print the parameters section from a docstring."""
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
header = "#" * header_level
print(f"\n{header} Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def _print_returns_section(
docstring_mdx: str, *, context_name: str, header_level: int
) -> None:
"""Print the returns section from a docstring."""
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
if return_matches or returntype_matches:
header = "#" * header_level
print(f"\n{header} Returns")
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {context_name}"
)
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
def _get_docstring(obj, use_dict_check: bool = False) -> str:
"""Get docstring from an object, with fallback to default message."""
# Check whether the class/method itself has docs and not just
# the superclass.
if use_dict_check:
has_doc = obj.__dict__.get("__doc__", None) is not None
else:
has_doc = getattr(obj, "__doc__", None) is not None
# We use inspect.getdoc because it does normalization.
doc = inspect.getdoc(obj)
return doc if has_doc and doc is not None else "No documentation available."
def _process_and_print_docstring(
docstring: str, *, kernel_name: str, context_name: str, header_level: int
) -> None:
"""Convert docstring to MDX and print description, parameters, and returns sections."""
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
)
# Print the description
description = _extract_description_before_tags(docstring_mdx)
print(f"\n{description}")
# Print parameters and returns sections
_print_parameters_section(docstring_mdx, header_level=header_level)
_print_returns_section(
docstring_mdx, context_name=context_name, header_level=header_level
)
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
kernel_name = repo_id.split("/")[-1].replace("-", "_")
@ -24,9 +105,10 @@ def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
generate_metadata(kernel_module)
generate_kernel_doc(kernel_module, kernel_name)
generate_function_doc(kernel_module, kernel_name)
generate_layers_doc(kernel_module, kernel_name)
def generate_metadata(module: ModuleType):
def generate_metadata(module: ModuleType) -> None:
metadata = getattr(module, "__kernel_metadata__", {})
if "tags" not in metadata:
metadata["tags"] = ["kernel"]
@ -39,7 +121,7 @@ def generate_metadata(module: ModuleType):
print("---")
def generate_kernel_doc(module: ModuleType, kernel_name: str):
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
docstring = module.__doc__.strip() if module.__doc__ is not None else None
if docstring:
title, rest = docstring.split("\n", 1)
@ -49,76 +131,112 @@ def generate_kernel_doc(module: ModuleType, kernel_name: str):
)
def generate_function_doc(kernel_module, kernel_name):
functions_info = []
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ == kernel_module.__name__:
# Exclude private functions.
if not name.startswith("_"):
try:
sig = inspect.signature(func)
docstring = inspect.getdoc(func) or "No documentation available."
functions_info.append((name, sig, docstring))
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
print("\n## Functions")
if not functions_info:
print(
"\nNo public top-level functions.",
)
return
# Track if we found any functions
found_functions = False
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ != kernel_module.__name__:
continue
# Exclude private functions.
if name.startswith("_"):
continue
found_functions = True
try:
sig = inspect.signature(func)
docstring = _get_docstring(func)
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
continue
for name, sig, docstring in functions_info:
print(f"\n### Function `{name}`")
print(f"\n`{sig}`")
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
_process_and_print_docstring(
docstring, kernel_name=kernel_name, context_name=name, header_level=3
)
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [
pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1
]
if not found_functions:
print("\nNo public top-level functions.")
if positions:
first_tag_pos = min(positions)
# The function description is anything before the first tag.
print(f"\n{docstring_mdx[:first_tag_pos].strip()}")
else:
print(f"\n{docstring_mdx.strip()}")
# Extract parameters
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
print("\n### Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
# Check if layers module is available
layers_module = getattr(kernel_module, "layers", None)
if layers_module is None:
return
# Extract return information
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
print("\n## Layers")
if return_matches or returntype_matches:
print("\n### Returns", file=sys.stdout)
# Track if we found any classes
found_classes = False
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {name} in {kernel_module.__name__}"
)
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
# Exclude classes that were imported.
if cls.__module__ != layers_module.__name__:
continue
found_classes = True
try:
# Get docstring, but not from superclasses.
class_docstring = _get_docstring(cls, use_dict_check=True)
except Exception:
print(
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Class `{class_name}`")
# Always print class description (helper handles conversion and formatting)
class_docstring_mdx = convert_rst_docstring_to_mdx(
class_docstring, page_info={"package_name": kernel_name}
)
description = _extract_description_before_tags(class_docstring_mdx)
print(f"\n{description}")
# Document methods
print("\n#### Methods")
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
# Note: also skip __init__, since extension layers cannot have a constructor.
if method_name.startswith("_"):
continue
# Skip methods from superclasses.
if method_name not in cls.__dict__:
continue
try:
sig = inspect.signature(method)
method_docstring = _get_docstring(method)
except ValueError:
print(
f"\n**Type**: {returntype_matches[0][0].strip()}", file=sys.stdout
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
file=sys.stderr,
)
continue
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
print(f"\n##### Method `{method_name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
method_docstring,
kernel_name=kernel_name,
context_name=method_name,
header_level=6,
)
if not found_classes:
print("\nNo layers defined.")

View File

@ -55,6 +55,7 @@ def build_variant() -> str:
os = platform.system().lower()
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"