mirror of
				https://github.com/huggingface/kernels.git
				synced 2025-10-31 19:54:28 +08:00 
			
		
		
		
	Compare commits
	
		
			8 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| a30e82182e | |||
| a7f3b2e8ed | |||
| a6ab5d83ba | |||
| 4f9f1abfb9 | |||
| f94b7780a6 | |||
| bd28883775 | |||
| 498429e322 | |||
| 09c991af4b | 
							
								
								
									
										2
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							| @ -63,7 +63,7 @@ jobs: | ||||
|       - name: Check README generation | ||||
|         # For now, just checks that generation doesn't fail. | ||||
|         run: | | ||||
|           uv run kernels generate-readme kernels-community/triton-layer-norm --revision docs | ||||
|           uv run kernels generate-readme kernels-community/triton-layer-norm | ||||
|  | ||||
|       - name: Import check without torch | ||||
|         run: | | ||||
|  | ||||
| @ -37,8 +37,14 @@ to resolve the version constraints. | ||||
| ## Native Python module | ||||
|  | ||||
| Kernels will typically contain a native Python module with precompiled | ||||
| compute kernels and bindings. This module must fulfill the following | ||||
| requirements: | ||||
| compute kernels and bindings. This module must fulfill the requirements | ||||
| outlined in this section. For all operating systems, a kernel must not | ||||
| have dynamic library dependencies outside: | ||||
|  | ||||
| - Torch; | ||||
| - CUDA/ROCm libraries installed as dependencies of Torch. | ||||
|  | ||||
| ### Linux | ||||
|  | ||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||
|   for compatibility with Python 3.9 and later. | ||||
| @ -50,12 +56,18 @@ requirements: | ||||
|   - CXXABI 1.3.11 | ||||
|   - GCC 7.0.0 | ||||
|  | ||||
|   These requirement can be checked with the ABI checker (see below). | ||||
| These requirement can be checked with the ABI checker (see below). | ||||
|  | ||||
| - No dynamic library dependencies outside: | ||||
| ### macOS | ||||
|  | ||||
|   - Torch; | ||||
|   - CUDA/ROCm libraries installed as dependencies of Torch. | ||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||
|   for compatibility with Python 3.9 and later. | ||||
| - macOS deployment target 15.0. | ||||
| - Metal 3.0 (`-std=metal3.0`). | ||||
|  | ||||
| The ABI3 requirement can be checked with the ABI checker (see below). | ||||
|  | ||||
| ### ABI checker | ||||
|  | ||||
| The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with | ||||
| [`kernel-abi-check`](https://crates.io/crates/kernel-abi-check): | ||||
|  | ||||
							
								
								
									
										7
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										7
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							| @ -58,16 +58,15 @@ | ||||
|         "nixpkgs": "nixpkgs" | ||||
|       }, | ||||
|       "locked": { | ||||
|         "lastModified": 1749025620, | ||||
|         "narHash": "sha256-V/r5KOp8FRC5n3MINDzTeS3pZz57SasFVzx12WQRQ8U=", | ||||
|         "lastModified": 1750775451, | ||||
|         "narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=", | ||||
|         "owner": "huggingface", | ||||
|         "repo": "hf-nix", | ||||
|         "rev": "7ab84ffad440c530162f528a96fa062530a6c8e4", | ||||
|         "rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "huggingface", | ||||
|         "ref": "torch-cxx11", | ||||
|         "repo": "hf-nix", | ||||
|         "type": "github" | ||||
|       } | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| { | ||||
|   inputs = { | ||||
|     hf-nix.url = "github:huggingface/hf-nix/torch-cxx11"; | ||||
|     hf-nix.url = "github:huggingface/hf-nix"; | ||||
|     nixpkgs.follows = "hf-nix/nixpkgs"; | ||||
|     flake-utils.url = "github:numtide/flake-utils"; | ||||
|   }; | ||||
| @ -16,7 +16,7 @@ | ||||
|       let | ||||
|         pkgs = import nixpkgs { | ||||
|           inherit system; | ||||
|           inherit (hf-nix.lib) config; | ||||
|           config = hf-nix.lib.config system; | ||||
|           overlays = [ | ||||
|             hf-nix.overlays.default | ||||
|           ]; | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| [project] | ||||
| name = "kernels" | ||||
| version = "0.6.0" | ||||
| version = "0.6.2" | ||||
| description = "Download compute kernels" | ||||
| authors = [ | ||||
|   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, | ||||
|  | ||||
| @ -17,6 +17,87 @@ _RE_RETURNTYPE = re.compile( | ||||
| ) | ||||
|  | ||||
|  | ||||
| def _extract_description_before_tags(docstring_mdx: str) -> str: | ||||
|     """Extract the description part of a docstring before any tags.""" | ||||
|     params_pos = docstring_mdx.find("<parameters>") | ||||
|     returns_pos = docstring_mdx.find("<returns>") | ||||
|     returntype_pos = docstring_mdx.find("<returntype>") | ||||
|     positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1] | ||||
|  | ||||
|     if positions: | ||||
|         first_tag_pos = min(positions) | ||||
|         return docstring_mdx[:first_tag_pos].strip() | ||||
|     else: | ||||
|         return docstring_mdx.strip() | ||||
|  | ||||
|  | ||||
| def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None: | ||||
|     """Print the parameters section from a docstring.""" | ||||
|     matches = _RE_PARAMETERS.findall(docstring_mdx) | ||||
|     if matches: | ||||
|         header = "#" * header_level | ||||
|         print(f"\n{header} Parameters") | ||||
|         for match in matches: | ||||
|             print(f"\n{match[0].strip()}") | ||||
|  | ||||
|  | ||||
| def _print_returns_section( | ||||
|     docstring_mdx: str, *, context_name: str, header_level: int | ||||
| ) -> None: | ||||
|     """Print the returns section from a docstring.""" | ||||
|     return_matches = _RE_RETURNS.findall(docstring_mdx) | ||||
|     returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx) | ||||
|  | ||||
|     if return_matches or returntype_matches: | ||||
|         header = "#" * header_level | ||||
|         print(f"\n{header} Returns") | ||||
|  | ||||
|         if returntype_matches: | ||||
|             if len(returntype_matches) > 1: | ||||
|                 raise ValueError( | ||||
|                     f"More than one <returntype> tag found in docstring for {context_name}" | ||||
|                 ) | ||||
|             print(f"\n**Type**: {returntype_matches[0][0].strip()}") | ||||
|  | ||||
|         if return_matches: | ||||
|             for match in return_matches: | ||||
|                 print(f"\n{match[0].strip()}") | ||||
|  | ||||
|  | ||||
| def _get_docstring(obj, use_dict_check: bool = False) -> str: | ||||
|     """Get docstring from an object, with fallback to default message.""" | ||||
|     # Check whether the class/method itself has docs and not just | ||||
|     # the superclass. | ||||
|     if use_dict_check: | ||||
|         has_doc = obj.__dict__.get("__doc__", None) is not None | ||||
|     else: | ||||
|         has_doc = getattr(obj, "__doc__", None) is not None | ||||
|  | ||||
|     # We use inspect.getdoc because it does normalization. | ||||
|     doc = inspect.getdoc(obj) | ||||
|  | ||||
|     return doc if has_doc and doc is not None else "No documentation available." | ||||
|  | ||||
|  | ||||
| def _process_and_print_docstring( | ||||
|     docstring: str, *, kernel_name: str, context_name: str, header_level: int | ||||
| ) -> None: | ||||
|     """Convert docstring to MDX and print description, parameters, and returns sections.""" | ||||
|     docstring_mdx = convert_rst_docstring_to_mdx( | ||||
|         docstring, page_info={"package_name": kernel_name} | ||||
|     ) | ||||
|  | ||||
|     # Print the description | ||||
|     description = _extract_description_before_tags(docstring_mdx) | ||||
|     print(f"\n{description}") | ||||
|  | ||||
|     # Print parameters and returns sections | ||||
|     _print_parameters_section(docstring_mdx, header_level=header_level) | ||||
|     _print_returns_section( | ||||
|         docstring_mdx, context_name=context_name, header_level=header_level | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None: | ||||
|     kernel_module = get_kernel(repo_id=repo_id, revision=revision) | ||||
|     kernel_name = repo_id.split("/")[-1].replace("-", "_") | ||||
| @ -24,9 +105,10 @@ def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None: | ||||
|     generate_metadata(kernel_module) | ||||
|     generate_kernel_doc(kernel_module, kernel_name) | ||||
|     generate_function_doc(kernel_module, kernel_name) | ||||
|     generate_layers_doc(kernel_module, kernel_name) | ||||
|  | ||||
|  | ||||
| def generate_metadata(module: ModuleType): | ||||
| def generate_metadata(module: ModuleType) -> None: | ||||
|     metadata = getattr(module, "__kernel_metadata__", {}) | ||||
|     if "tags" not in metadata: | ||||
|         metadata["tags"] = ["kernel"] | ||||
| @ -39,7 +121,7 @@ def generate_metadata(module: ModuleType): | ||||
|     print("---") | ||||
|  | ||||
|  | ||||
| def generate_kernel_doc(module: ModuleType, kernel_name: str): | ||||
| def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None: | ||||
|     docstring = module.__doc__.strip() if module.__doc__ is not None else None | ||||
|     if docstring: | ||||
|         title, rest = docstring.split("\n", 1) | ||||
| @ -49,76 +131,112 @@ def generate_kernel_doc(module: ModuleType, kernel_name: str): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def generate_function_doc(kernel_module, kernel_name): | ||||
|     functions_info = [] | ||||
| def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None: | ||||
|     print("\n## Functions") | ||||
|  | ||||
|     # Track if we found any functions | ||||
|     found_functions = False | ||||
|  | ||||
|     for name, func in inspect.getmembers(kernel_module, inspect.isfunction): | ||||
|         # Do not include imported functions. | ||||
|         if func.__module__ == kernel_module.__name__: | ||||
|         if func.__module__ != kernel_module.__name__: | ||||
|             continue | ||||
|  | ||||
|         # Exclude private functions. | ||||
|             if not name.startswith("_"): | ||||
|         if name.startswith("_"): | ||||
|             continue | ||||
|  | ||||
|         found_functions = True | ||||
|  | ||||
|         try: | ||||
|             sig = inspect.signature(func) | ||||
|                     docstring = inspect.getdoc(func) or "No documentation available." | ||||
|                     functions_info.append((name, sig, docstring)) | ||||
|             docstring = _get_docstring(func) | ||||
|         except ValueError: | ||||
|             print( | ||||
|                 f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}", | ||||
|                 file=sys.stderr, | ||||
|             ) | ||||
|             continue | ||||
|  | ||||
|     print("\n## Functions") | ||||
|  | ||||
|     if not functions_info: | ||||
|         print( | ||||
|             "\nNo public top-level functions.", | ||||
|         ) | ||||
|         return | ||||
|  | ||||
|     for name, sig, docstring in functions_info: | ||||
|         print(f"\n### Function `{name}`") | ||||
|         print(f"\n`{sig}`") | ||||
|  | ||||
|         docstring_mdx = convert_rst_docstring_to_mdx( | ||||
|             docstring, page_info={"package_name": kernel_name} | ||||
|         _process_and_print_docstring( | ||||
|             docstring, kernel_name=kernel_name, context_name=name, header_level=3 | ||||
|         ) | ||||
|  | ||||
|         params_pos = docstring_mdx.find("<parameters>") | ||||
|         returns_pos = docstring_mdx.find("<returns>") | ||||
|         returntype_pos = docstring_mdx.find("<returntype>") | ||||
|         positions = [ | ||||
|             pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1 | ||||
|         ] | ||||
|     if not found_functions: | ||||
|         print("\nNo public top-level functions.") | ||||
|  | ||||
|         if positions: | ||||
|             first_tag_pos = min(positions) | ||||
|             # The function description is anything before the first tag. | ||||
|             print(f"\n{docstring_mdx[:first_tag_pos].strip()}") | ||||
|         else: | ||||
|             print(f"\n{docstring_mdx.strip()}") | ||||
|  | ||||
|         # Extract parameters | ||||
|         matches = _RE_PARAMETERS.findall(docstring_mdx) | ||||
|         if matches: | ||||
|             print("\n### Parameters") | ||||
|             for match in matches: | ||||
|                 print(f"\n{match[0].strip()}") | ||||
| def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None: | ||||
|     # Check if layers module is available | ||||
|     layers_module = getattr(kernel_module, "layers", None) | ||||
|     if layers_module is None: | ||||
|         return | ||||
|  | ||||
|         # Extract return information | ||||
|         return_matches = _RE_RETURNS.findall(docstring_mdx) | ||||
|         returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx) | ||||
|     print("\n## Layers") | ||||
|  | ||||
|         if return_matches or returntype_matches: | ||||
|             print("\n### Returns", file=sys.stdout) | ||||
|     # Track if we found any classes | ||||
|     found_classes = False | ||||
|  | ||||
|             if returntype_matches: | ||||
|                 if len(returntype_matches) > 1: | ||||
|                     raise ValueError( | ||||
|                         f"More than one <returntype> tag found in docstring for {name} in {kernel_module.__name__}" | ||||
|                     ) | ||||
|     for class_name, cls in inspect.getmembers(layers_module, inspect.isclass): | ||||
|         # Exclude classes that were imported. | ||||
|         if cls.__module__ != layers_module.__name__: | ||||
|             continue | ||||
|  | ||||
|         found_classes = True | ||||
|  | ||||
|         try: | ||||
|             # Get docstring, but not from superclasses. | ||||
|             class_docstring = _get_docstring(cls, use_dict_check=True) | ||||
|         except Exception: | ||||
|             print( | ||||
|                     f"\n**Type**: {returntype_matches[0][0].strip()}", file=sys.stdout | ||||
|                 f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}", | ||||
|                 file=sys.stderr, | ||||
|             ) | ||||
|             continue | ||||
|  | ||||
|         print(f"\n### Class `{class_name}`") | ||||
|  | ||||
|         # Always print class description (helper handles conversion and formatting) | ||||
|         class_docstring_mdx = convert_rst_docstring_to_mdx( | ||||
|             class_docstring, page_info={"package_name": kernel_name} | ||||
|         ) | ||||
|         description = _extract_description_before_tags(class_docstring_mdx) | ||||
|         print(f"\n{description}") | ||||
|  | ||||
|         # Document methods | ||||
|         print("\n#### Methods") | ||||
|  | ||||
|         for method_name, method in inspect.getmembers(cls, inspect.isfunction): | ||||
|             # Note: also skip __init__, since extension layers cannot have a constructor. | ||||
|             if method_name.startswith("_"): | ||||
|                 continue | ||||
|  | ||||
|             # Skip methods from superclasses. | ||||
|             if method_name not in cls.__dict__: | ||||
|                 continue | ||||
|  | ||||
|             try: | ||||
|                 sig = inspect.signature(method) | ||||
|                 method_docstring = _get_docstring(method) | ||||
|             except ValueError: | ||||
|                 print( | ||||
|                     f"Warning: Could not retrieve signature for {method_name} in {class_name}", | ||||
|                     file=sys.stderr, | ||||
|                 ) | ||||
|                 continue | ||||
|  | ||||
|             print(f"\n##### Method `{method_name}`") | ||||
|             print(f"\n`{sig}`") | ||||
|  | ||||
|             _process_and_print_docstring( | ||||
|                 method_docstring, | ||||
|                 kernel_name=kernel_name, | ||||
|                 context_name=method_name, | ||||
|                 header_level=6, | ||||
|             ) | ||||
|  | ||||
|             if return_matches: | ||||
|                 for match in return_matches: | ||||
|                     print(f"\n{match[0].strip()}") | ||||
|     if not found_classes: | ||||
|         print("\nNo layers defined.") | ||||
|  | ||||
| @ -55,6 +55,7 @@ def build_variant() -> str: | ||||
|     os = platform.system().lower() | ||||
|  | ||||
|     if os == "darwin": | ||||
|         cpu = "aarch64" if cpu == "arm64" else cpu | ||||
|         return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" | ||||
|  | ||||
|     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	