mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-21 21:38:52 +08:00
Compare commits
22 Commits
v0.3.1
...
fix-comman
Author | SHA1 | Date | |
---|---|---|---|
03a8662f7f | |||
cf530c283a | |||
437f910336 | |||
6f1a6067c8 | |||
1d14abcef0 | |||
6fd2112e22 | |||
70f56ff856 | |||
7178b0b86c | |||
0bbf90a564 | |||
27d6ffcb80 | |||
f7bd21438b | |||
6174febb4b | |||
ff55bc201b | |||
3808108d62 | |||
c4a16ef462 | |||
9762794dd2 | |||
b7d6867c52 | |||
fbcd0f2ebd | |||
5af46eca94 | |||
747dd66876 | |||
920590a592 | |||
5208ac4be5 |
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@ -52,3 +52,8 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest tests
|
||||
|
||||
- name: Import check without torch
|
||||
run: |
|
||||
uv pip uninstall torch
|
||||
python -c "import kernels"
|
||||
|
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
12
README.md
12
README.md
@ -1,5 +1,16 @@
|
||||
# kernels
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
|
||||
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
|
||||
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
|
||||
|
||||
</p>
|
||||
</div>
|
||||
<hr/>
|
||||
|
||||
The Kernel Hub allows Python libraries and applications to load compute
|
||||
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||
@ -47,6 +58,7 @@ the Hub.
|
||||
|
||||
- [Using layers](docs/layers.md)
|
||||
- [Locking kernel versions](docs/locking.md)
|
||||
- [Environment variables](docs/env.md)
|
||||
- [Using kernels in a Docker container](docs/docker.md)
|
||||
- [Kernel requirements](docs/kernel-requirements.md)
|
||||
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||
|
10
docs/env.md
Normal file
10
docs/env.md
Normal file
@ -0,0 +1,10 @@
|
||||
# Environment variables
|
||||
|
||||
## `KERNELS_CACHE`
|
||||
|
||||
The directory to use as the local kernel cache. If not set, the cache
|
||||
of the `huggingface_hub` package is used.
|
||||
|
||||
## `DISABLE_KERNEL_MAPPING`
|
||||
|
||||
Disables kernel mappings for [`layers`](layers.md).
|
@ -38,6 +38,12 @@ as the repository (replacing `-` by `_`). For instance, kernels in the
|
||||
`build/<variant>/activation`. This directory
|
||||
must be a Python package with an `__init__.py` file.
|
||||
|
||||
## Versioning
|
||||
|
||||
Kernels are versioned on the Hub using Git tags. Version tags must be of
|
||||
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
||||
to resolve the version constraints.
|
||||
|
||||
## Native Python module
|
||||
|
||||
Kernels will typically contain a native Python module with precompiled
|
||||
@ -46,16 +52,31 @@ requirements:
|
||||
|
||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||
for compatibility with Python 3.9 and later.
|
||||
- Compatible with glibc 2.27 or later. This means that no symbols
|
||||
from later versions must be used. To archive this, the module should
|
||||
be built against this glibc version. **Warning:** libgcc must also be
|
||||
built against glibc 2.27 to avoid leaking symbols.
|
||||
- No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols
|
||||
must be static.
|
||||
- No dynamic library dependencies outside Torch or CUDA libraries
|
||||
installed as dependencies of Torch.
|
||||
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
||||
This means that the extension **must not** use symbols versions higher than:
|
||||
|
||||
(These requirements will be updated as new PyTorch versions are released.)
|
||||
- GLIBC 2.28
|
||||
- GLIBCXX 3.4.24
|
||||
- CXXABI 1.3.11
|
||||
- GCC 7.0.0
|
||||
|
||||
These requirement can be checked with the ABI checker (see below).
|
||||
|
||||
- No dynamic library dependencies outside:
|
||||
|
||||
- Torch;
|
||||
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||
|
||||
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
|
||||
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
|
||||
|
||||
```bash
|
||||
|
||||
$ cargo install kernel-abi-check
|
||||
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
|
||||
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
|
||||
✅ No compatibility issues found
|
||||
```
|
||||
|
||||
## Torch extension
|
||||
|
||||
@ -98,10 +119,17 @@ requirements:
|
||||
- The `forward` method has a signature that is compatible with the
|
||||
`forward` method that it is extending.
|
||||
|
||||
The only exception to the _no class variables rule_ is addition of a
|
||||
`has_backward` class variable. This variable is used to indicate whether
|
||||
the layer has a backward pass implemented (`True` when absent).
|
||||
|
||||
This is an example of a pure layer:
|
||||
|
||||
```python
|
||||
class SiluAndMul(nn.Module):
|
||||
# This layer does not implement backward.
|
||||
has_backward: bool = False
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
|
@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
|
||||
"kernels-community/activation" = ">=0.0.1"
|
||||
```
|
||||
|
||||
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
|
||||
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
|
||||
the locked revisions. The locked revision will be used when loading a kernel with
|
||||
`get_locked_kernel`:
|
||||
|
||||
@ -28,7 +28,7 @@ to `kernels` after doing an (editable or regular) installation of your project.
|
||||
|
||||
## Pre-downloading locked kernels
|
||||
|
||||
Locked kernels can be pre-downloaded by running `kernel download .` in your
|
||||
Locked kernels can be pre-downloaded by running `kernels download .` in your
|
||||
project directory. This will download the kernels to your local Hugging Face
|
||||
Hub cache.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "kernels"
|
||||
version = "0.3.1"
|
||||
version = "0.4.4"
|
||||
description = "Download compute kernels"
|
||||
authors = [
|
||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||
@ -8,13 +8,13 @@ authors = [
|
||||
{ name = "David Holtz", email = "david@huggingface.co" },
|
||||
{ name = "Nicolas Patry", email = "nicolas@huggingface.co" },
|
||||
]
|
||||
license = { text = "Apache-2.0" }
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.9"
|
||||
dependencies = [
|
||||
"huggingface-hub>=0.26.3",
|
||||
"packaging>=24.2",
|
||||
"tomli>=2.0.1; python_version<'3.11'",
|
||||
"torch>=2.5",
|
||||
"huggingface_hub>=0.26.0,<1.0",
|
||||
"packaging>=20.0",
|
||||
"tomli>=2.0; python_version<'3.11'",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
@ -27,8 +27,12 @@ dev = [
|
||||
"pytest >=8",
|
||||
# Whatever version is compatible with pytest.
|
||||
"pytest-benchmark",
|
||||
"torch >=2.5",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
torch = ["torch"]
|
||||
|
||||
[project.scripts]
|
||||
kernels = "kernels.cli:main"
|
||||
|
||||
|
@ -2,11 +2,14 @@ from kernels.layer import (
|
||||
Device,
|
||||
LayerRepository,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
use_kernel_forward_from_hub,
|
||||
use_kernel_mapping,
|
||||
)
|
||||
from kernels.utils import (
|
||||
get_kernel,
|
||||
get_locked_kernel,
|
||||
has_kernel,
|
||||
install_kernel,
|
||||
load_kernel,
|
||||
)
|
||||
@ -14,10 +17,13 @@ from kernels.utils import (
|
||||
__all__ = [
|
||||
"get_kernel",
|
||||
"get_locked_kernel",
|
||||
"has_kernel",
|
||||
"load_kernel",
|
||||
"install_kernel",
|
||||
"use_kernel_forward_from_hub",
|
||||
"use_kernel_mapping",
|
||||
"register_kernel_mapping",
|
||||
"replace_kernel_forward_from_hub",
|
||||
"LayerRepository",
|
||||
"Device",
|
||||
]
|
||||
|
@ -1,14 +1,18 @@
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Union
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
|
||||
from .utils import get_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import nn
|
||||
|
||||
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Device:
|
||||
@ -54,11 +58,26 @@ _KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextV
|
||||
)
|
||||
|
||||
|
||||
def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]):
|
||||
def use_kernel_mapping(
|
||||
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
|
||||
*,
|
||||
inherit_mapping: bool = True,
|
||||
):
|
||||
"""
|
||||
Context manager that sets a mapping for a duration of the context.
|
||||
|
||||
When `inherit_mapping` is set to `True` the current mapping will be
|
||||
extended by `mapping` inside the context. If it is `False`, only
|
||||
`mapping` is used inside the context.
|
||||
"""
|
||||
|
||||
class ContextManager:
|
||||
def __enter__(self):
|
||||
# Mappings always stack on previous mappings.
|
||||
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
|
||||
if inherit_mapping:
|
||||
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
|
||||
else:
|
||||
self.token = _KERNEL_MAPPING.set({})
|
||||
register_kernel_mapping(mapping)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
@ -112,11 +131,21 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
|
||||
fallback_forward = cls.forward
|
||||
|
||||
cached_forward: Dict[LayerRepository, Callable] = {}
|
||||
cached_layer: Dict[LayerRepository, nn.Module] = {}
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if _DISABLE_KERNEL_MAPPING:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
needs_backward = self.training
|
||||
kernel = _KERNEL_MAPPING.get().get(layer_name)
|
||||
if kernel is None:
|
||||
warnings.warn(
|
||||
"\n"
|
||||
f"No kernel mapping found for layer `{layer_name}`. "
|
||||
f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
|
||||
f"you want to use to the mapping. Defaulting to original forward implementation."
|
||||
)
|
||||
if not use_fallback:
|
||||
raise ValueError(f"No layer mapping for `{layer_name}`")
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
@ -134,9 +163,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
# Short-circuit if we already loaded the layer.
|
||||
layer_forward = cached_forward.get(repo, None)
|
||||
if layer_forward is not None:
|
||||
return layer_forward(self, x, *args, **kwargs)
|
||||
layer = cached_layer.get(repo, None)
|
||||
if layer is not None:
|
||||
if needs_backward and not getattr(layer, "has_backward", True):
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
|
||||
layer = _get_kernel_layer(
|
||||
repo_id=repo.repo_id,
|
||||
@ -152,10 +183,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
finally:
|
||||
cls.forward = orig_forward
|
||||
|
||||
layer_forward = layer.forward
|
||||
cached_forward[repo] = layer_forward
|
||||
cached_layer[repo] = layer
|
||||
|
||||
return layer_forward(self, x, *args, **kwargs)
|
||||
if needs_backward and not getattr(layer, "has_backward", True):
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
|
||||
cls.forward = forward
|
||||
|
||||
@ -212,7 +244,8 @@ def _validate_layer(*, check_cls, cls):
|
||||
# ... or predefined member variables.
|
||||
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
||||
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
||||
if cls_members - torch_module_members != set():
|
||||
difference = cls_members - torch_module_members
|
||||
if difference != set() and difference != {"has_backward"}:
|
||||
raise TypeError("Layer must not contain additional members.")
|
||||
|
||||
# Check whether the forward signatures are similar.
|
||||
|
@ -4,6 +4,7 @@ import importlib
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
@ -12,29 +13,45 @@ from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import file_exists, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
from kernels.lockfile import KernelLock, VariantLock
|
||||
|
||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
|
||||
def _get_cache_dir() -> Optional[str]:
|
||||
"""Returns the kernels cache directory."""
|
||||
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
if cache_dir is not None:
|
||||
logging.warning(
|
||||
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
|
||||
)
|
||||
return cache_dir
|
||||
|
||||
return os.environ.get("KERNELS_CACHE", None)
|
||||
|
||||
|
||||
CACHE_DIR: Optional[str] = _get_cache_dir()
|
||||
|
||||
|
||||
def build_variant() -> str:
|
||||
import torch
|
||||
|
||||
if torch.version.cuda is None:
|
||||
raise AssertionError(
|
||||
"This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled."
|
||||
)
|
||||
if torch.version.cuda is not None:
|
||||
cuda_version = parse(torch.version.cuda)
|
||||
compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
|
||||
elif torch.version.hip is not None:
|
||||
rocm_version = parse(torch.version.hip.split("-")[0])
|
||||
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||
else:
|
||||
raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
|
||||
|
||||
torch_version = parse(torch.__version__)
|
||||
cuda_version = parse(torch.version.cuda)
|
||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||
cpu = platform.machine()
|
||||
os = platform.system().lower()
|
||||
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
|
||||
|
||||
|
||||
def universal_build_variant() -> str:
|
||||
@ -144,6 +161,29 @@ def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment
|
||||
(Torch version and compute framework).
|
||||
"""
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
||||
if file_exists(
|
||||
repo_id,
|
||||
revision=revision,
|
||||
filename=f"build/{universal_variant}/{package_name}/__init__.py",
|
||||
):
|
||||
return True
|
||||
|
||||
return file_exists(
|
||||
repo_id,
|
||||
revision=revision,
|
||||
filename=f"build/{variant}/{package_name}/__init__.py",
|
||||
)
|
||||
|
||||
|
||||
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
||||
"""
|
||||
Get a pre-downloaded, locked kernel.
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels import get_kernel
|
||||
from kernels import get_kernel, has_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -36,6 +36,22 @@ def test_gelu_fast(kernel, device):
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_exists",
|
||||
[
|
||||
("kernels-community/activation", "main", True),
|
||||
("kernels-community/triton-layer-norm", "main", True),
|
||||
# Repo only contains Torch 2.4 kernels (and we don't
|
||||
# support/test against this version).
|
||||
("kernels-test/only-torch-2.4", "main", False),
|
||||
("google-bert/bert-base-uncased", "87565a309", False),
|
||||
],
|
||||
)
|
||||
def test_has_kernel(kernel_exists):
|
||||
repo_id, revision, kernel = kernel_exists
|
||||
assert has_kernel(repo_id, revision=revision) == kernel
|
||||
|
||||
|
||||
def test_universal_kernel(universal_kernel):
|
||||
torch.manual_seed(0)
|
||||
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||
|
@ -152,6 +152,25 @@ def test_mapping_contexts():
|
||||
== "kernels-community/activation"
|
||||
)
|
||||
|
||||
with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
}
|
||||
assert (
|
||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||
== "kernels-community/non-existing"
|
||||
)
|
||||
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
"SiluAndMulStringDevice",
|
||||
"TestKernel",
|
||||
}
|
||||
assert (
|
||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||
== "kernels-community/activation"
|
||||
)
|
||||
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
"SiluAndMulStringDevice",
|
||||
@ -184,3 +203,75 @@ def test_validate_kernel_layer():
|
||||
|
||||
with pytest.raises(TypeError, match="different kind of arguments"):
|
||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
||||
|
||||
|
||||
def test_fallback_used_when_training():
|
||||
@use_kernel_forward_from_hub("Linear")
|
||||
class TorchLinear(nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Used to check that we called hub kernel.
|
||||
self.n_calls = 0
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
self.n_calls += 1
|
||||
return super().forward(input)
|
||||
|
||||
linear = TorchLinear(32, 32).to("cuda")
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearImplicitBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearNoBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 1
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 1
|
||||
|
Reference in New Issue
Block a user