mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-23 14:59:08 +08:00
Compare commits
31 Commits
Author | SHA1 | Date | |
---|---|---|---|
ae43772a67 | |||
c02f88cd2a | |||
a3db6f437c | |||
5e938ede40 | |||
cf530c283a | |||
437f910336 | |||
6f1a6067c8 | |||
1d14abcef0 | |||
6fd2112e22 | |||
70f56ff856 | |||
7178b0b86c | |||
0bbf90a564 | |||
27d6ffcb80 | |||
f7bd21438b | |||
6174febb4b | |||
ff55bc201b | |||
3808108d62 | |||
c4a16ef462 | |||
9762794dd2 | |||
b7d6867c52 | |||
fbcd0f2ebd | |||
5af46eca94 | |||
747dd66876 | |||
920590a592 | |||
5208ac4be5 | |||
22eaba2826 | |||
9521ba79a0 | |||
9861a5bdef | |||
1c7c87c960 | |||
df45cf2795 | |||
cf0413efe5 |
19
.github/workflows/build_documentation.yaml
vendored
Normal file
19
.github/workflows/build_documentation.yaml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
name: Build documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- "docs/source/**"
|
||||
branches:
|
||||
- main
|
||||
- doc-builder*
|
||||
- v*-release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: kernels
|
||||
secrets:
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
18
.github/workflows/build_pr_documentation.yaml
vendored
Normal file
18
.github/workflows/build_pr_documentation.yaml
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
name: Build PR Documentation
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "docs/source/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: kernels
|
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@ -52,3 +52,8 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest tests
|
||||
|
||||
- name: Import check without torch
|
||||
run: |
|
||||
uv pip uninstall torch
|
||||
python -c "import kernels"
|
||||
|
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
13
README.md
13
README.md
@ -1,5 +1,16 @@
|
||||
# kernels
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
|
||||
<a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
|
||||
<a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
|
||||
|
||||
</p>
|
||||
</div>
|
||||
<hr/>
|
||||
|
||||
The Kernel Hub allows Python libraries and applications to load compute
|
||||
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||
@ -45,7 +56,9 @@ the Hub.
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [Using layers](docs/layers.md)
|
||||
- [Locking kernel versions](docs/locking.md)
|
||||
- [Environment variables](docs/env.md)
|
||||
- [Using kernels in a Docker container](docs/docker.md)
|
||||
- [Kernel requirements](docs/kernel-requirements.md)
|
||||
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||
|
26
docs/source/_toctree.yml
Normal file
26
docs/source/_toctree.yml
Normal file
@ -0,0 +1,26 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: Introduction
|
||||
- local: installation
|
||||
title: Installation
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: basic_usage
|
||||
title: Basic Usage
|
||||
- local: layers
|
||||
title: Using Layers
|
||||
- local: locking
|
||||
title: Locking Kernel Versions
|
||||
- local: env
|
||||
title: Environment Variables
|
||||
title: Usage Guide
|
||||
- sections:
|
||||
- local: api/kernels
|
||||
title: Kernels
|
||||
- local: api/layers
|
||||
title: Layers
|
||||
title: API Reference
|
||||
- sections:
|
||||
- local: kernel_requirements
|
||||
title: Kernel Requirements
|
||||
title: Developer Guide
|
21
docs/source/api/kernels.md
Normal file
21
docs/source/api/kernels.md
Normal file
@ -0,0 +1,21 @@
|
||||
# Kernels API Reference
|
||||
|
||||
## Main Functions
|
||||
|
||||
### get_kernel
|
||||
|
||||
[[autodoc]] kernels.get_kernel
|
||||
|
||||
### has_kernel
|
||||
|
||||
[[autodoc]] kernels.has_kernel
|
||||
|
||||
## Loading locked kernels
|
||||
|
||||
### load_kernel
|
||||
|
||||
[[autodoc]] kernels.load_kernel
|
||||
|
||||
### get_locked_kernel
|
||||
|
||||
[[autodoc]] kernels.get_locked_kernel
|
31
docs/source/api/layers.md
Normal file
31
docs/source/api/layers.md
Normal file
@ -0,0 +1,31 @@
|
||||
# Layers API Reference
|
||||
|
||||
## Making layers kernel-aware
|
||||
|
||||
### use_kernel_forward_from_hub
|
||||
|
||||
[[autodoc]] kernels.use_kernel_forward_from_hub
|
||||
|
||||
### replace_kernel_forward_from_hub
|
||||
|
||||
[[autodoc]] kernels.replace_kernel_forward_from_hub
|
||||
|
||||
## Registering kernel mappings
|
||||
|
||||
### use_kernel_mapping
|
||||
|
||||
[[autodoc]] kernels.use_kernel_mapping
|
||||
|
||||
### register_kernel_mapping
|
||||
|
||||
[[autodoc]] kernels.register_kernel_mapping
|
||||
|
||||
## Classes
|
||||
|
||||
### LayerRepository
|
||||
|
||||
[[autodoc]] kernels.LayerRepository
|
||||
|
||||
### Device
|
||||
|
||||
[[autodoc]] kernels.Device
|
34
docs/source/basic_usage.md
Normal file
34
docs/source/basic_usage.md
Normal file
@ -0,0 +1,34 @@
|
||||
# Basic Usage
|
||||
|
||||
## Loading Kernels
|
||||
|
||||
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from kernels import get_kernel
|
||||
|
||||
# Download optimized kernels from the Hugging Face hub
|
||||
activation = get_kernel("kernels-community/activation")
|
||||
|
||||
# Create a random tensor
|
||||
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
|
||||
|
||||
# Run the kernel
|
||||
y = torch.empty_like(x)
|
||||
activation.gelu_fast(y, x)
|
||||
|
||||
print(y)
|
||||
```
|
||||
|
||||
## Checking Kernel Availability
|
||||
|
||||
You can check if a specific kernel is available for your environment:
|
||||
|
||||
```python
|
||||
from kernels import has_kernel
|
||||
|
||||
# Check if kernel is available for current environment
|
||||
is_available = has_kernel("kernels-community/activation")
|
||||
print(f"Kernel available: {is_available}")
|
||||
```
|
10
docs/source/env.md
Normal file
10
docs/source/env.md
Normal file
@ -0,0 +1,10 @@
|
||||
# Environment variables
|
||||
|
||||
## `KERNELS_CACHE`
|
||||
|
||||
The directory to use as the local kernel cache. If not set, the cache
|
||||
of the `huggingface_hub` package is used.
|
||||
|
||||
## `DISABLE_KERNEL_MAPPING`
|
||||
|
||||
Disables kernel mappings for [`layers`](layers.md).
|
20
docs/source/index.md
Normal file
20
docs/source/index.md
Normal file
@ -0,0 +1,20 @@
|
||||
# Kernels
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
||||
</div>
|
||||
|
||||
The Kernel Hub allows Python libraries and applications to load compute
|
||||
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||
packages in that they are made to be:
|
||||
|
||||
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
|
||||
- **Unique**: multiple versions of the same kernel can be loaded in the
|
||||
same Python process.
|
||||
- **Compatible**: kernels must support all recent versions of Python and
|
||||
the different PyTorch build configurations (various CUDA versions
|
||||
and C++ ABIs). Furthermore, older C library versions must be supported.
|
||||
|
||||
You can [search for kernels](https://huggingface.co/models?other=kernel) on
|
||||
the Hub.
|
16
docs/source/installation.md
Normal file
16
docs/source/installation.md
Normal file
@ -0,0 +1,16 @@
|
||||
# Installation
|
||||
|
||||
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
|
||||
|
||||
```bash
|
||||
pip install kernels
|
||||
```
|
||||
|
||||
# Using kernels in a Docker container
|
||||
|
||||
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
|
||||
|
||||
```bash
|
||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
||||
```
|
@ -38,6 +38,12 @@ as the repository (replacing `-` by `_`). For instance, kernels in the
|
||||
`build/<variant>/activation`. This directory
|
||||
must be a Python package with an `__init__.py` file.
|
||||
|
||||
## Versioning
|
||||
|
||||
Kernels are versioned on the Hub using Git tags. Version tags must be of
|
||||
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
||||
to resolve the version constraints.
|
||||
|
||||
## Native Python module
|
||||
|
||||
Kernels will typically contain a native Python module with precompiled
|
||||
@ -46,16 +52,31 @@ requirements:
|
||||
|
||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||
for compatibility with Python 3.9 and later.
|
||||
- Compatible with glibc 2.27 or later. This means that no symbols
|
||||
from later versions must be used. To archive this, the module should
|
||||
be built against this glibc version. **Warning:** libgcc must also be
|
||||
built against glibc 2.27 to avoid leaking symbols.
|
||||
- No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols
|
||||
must be static.
|
||||
- No dynamic library dependencies outside Torch or CUDA libraries
|
||||
installed as dependencies of Torch.
|
||||
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
||||
This means that the extension **must not** use symbols versions higher than:
|
||||
|
||||
(These requirements will be updated as new PyTorch versions are released.)
|
||||
- GLIBC 2.28
|
||||
- GLIBCXX 3.4.24
|
||||
- CXXABI 1.3.11
|
||||
- GCC 7.0.0
|
||||
|
||||
These requirement can be checked with the ABI checker (see below).
|
||||
|
||||
- No dynamic library dependencies outside:
|
||||
|
||||
- Torch;
|
||||
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||
|
||||
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
|
||||
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
|
||||
|
||||
```bash
|
||||
|
||||
$ cargo install kernel-abi-check
|
||||
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
|
||||
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
|
||||
✅ No compatibility issues found
|
||||
```
|
||||
|
||||
## Torch extension
|
||||
|
||||
@ -76,6 +97,87 @@ might use two different commits that happen to have the same version
|
||||
number. Git tags are not stable, so they do not provide a good way
|
||||
of guaranteeing uniqueness of the namespace.
|
||||
|
||||
## Layers
|
||||
|
||||
A kernel can provide layers in addition to kernel functions. A layer from
|
||||
the Hub can replace the `forward` method of an existing layer for a certain
|
||||
device type. This makes it possible to provide more performant kernels for
|
||||
existing layers. See the [layers documentation](layers.md) for more information
|
||||
on how to use layers.
|
||||
|
||||
### Writing layers
|
||||
|
||||
To make the extension of layers safe, the layers must fulfill the following
|
||||
requirements:
|
||||
|
||||
- The layers are subclasses of `torch.nn.Module`.
|
||||
- The layers are pure, meaning that they do not have their own state. This
|
||||
means that:
|
||||
- The layer must not define its own constructor.
|
||||
- The layer must not use class variables.
|
||||
- No other methods must be defined than `forward`.
|
||||
- The `forward` method has a signature that is compatible with the
|
||||
`forward` method that it is extending.
|
||||
|
||||
The only exception to the _no class variables rule_ is addition of a
|
||||
`has_backward` class variable. This variable is used to indicate whether
|
||||
the layer has a backward pass implemented (`True` when absent).
|
||||
|
||||
This is an example of a pure layer:
|
||||
|
||||
```python
|
||||
class SiluAndMul(nn.Module):
|
||||
# This layer does not implement backward.
|
||||
has_backward: bool = False
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
ops.silu_and_mul(out, x)
|
||||
return out
|
||||
```
|
||||
|
||||
For some layers, the `forward` method has to use state from the adopting class.
|
||||
In these cases, we recommend to use type annotations to indicate what member
|
||||
variables are expected. For instance:
|
||||
|
||||
```python
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
weight: torch.Tensor
|
||||
variance_epsilon: float
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return rms_norm_fn(
|
||||
hidden_states,
|
||||
self.weight,
|
||||
bias=None,
|
||||
residual=None,
|
||||
eps=self.variance_epsilon,
|
||||
dropout_p=0.0,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
)
|
||||
```
|
||||
|
||||
This layer expects the adopting layer to have `weight` and `variance_epsilon`
|
||||
member variables and uses them in the `forward` method.
|
||||
|
||||
### Exporting layers
|
||||
|
||||
To accommodate portable loading, `layers` must be defined in the main
|
||||
`__init__.py` file. For example:
|
||||
|
||||
```python
|
||||
from . import layers
|
||||
|
||||
__all__ = [
|
||||
# ...
|
||||
"layers"
|
||||
# ...
|
||||
]
|
||||
```
|
||||
|
||||
## Python requirements
|
||||
|
||||
- Python code must be compatible with Python 3.9 and later.
|
79
docs/source/layers.md
Normal file
79
docs/source/layers.md
Normal file
@ -0,0 +1,79 @@
|
||||
# Layers
|
||||
|
||||
A kernel can provide layers in addition to kernel functions. A layer from
|
||||
the Hub can replace the `forward` method of an existing layer for a certain
|
||||
device type. This makes it possible to provide more performant kernels for
|
||||
existing layers.
|
||||
|
||||
See [Kernel requirements](kernel-requirements.md) for more information the
|
||||
requirements of Hub layers.
|
||||
|
||||
## Making a layer extensible with kernels from the hub
|
||||
|
||||
### Using a decorator
|
||||
|
||||
A layer can be made extensible with the `use_kernel_forward_from_hub`
|
||||
decorator. For example:
|
||||
|
||||
```python
|
||||
@use_kernel_forward_from_hub("SiluAndMul")
|
||||
class SiluAndMul(nn.Module):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
d = input.shape[-1] // 2
|
||||
return F.silu(input[..., :d]) * input[..., d:]
|
||||
```
|
||||
|
||||
The decorator changes the layer, so that other implementations of the `forward`
|
||||
method can be registered using the name `SiluAndMul`.
|
||||
|
||||
### External layers
|
||||
|
||||
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
|
||||
decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
|
||||
|
||||
```python
|
||||
from somelibrary import SiluAndMul
|
||||
|
||||
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
```
|
||||
|
||||
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
|
||||
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
|
||||
section for more information.
|
||||
|
||||
**Warning:** we strongly recommend using layers with a decorator, since
|
||||
it signifies that the maintainer intends to keep the `forward` signature
|
||||
compatible with layers from the hub.
|
||||
|
||||
## Registering a hub kernel for a layer
|
||||
|
||||
Once a layer is made extensible, users can register hub kernels for it
|
||||
by name using the `register_kernel_mapping` function. For example:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
```
|
||||
|
||||
This will register the kernel mapping in the current context, which is
|
||||
normally global. It is recommended to scope the mapping to where it is
|
||||
used with the `use_kernel_mapping` context manager:
|
||||
|
||||
```python
|
||||
with use_kernel_mapping(kernel_layer_mapping):
|
||||
# Use the layer for which the mapping is applied.
|
||||
...
|
||||
```
|
||||
|
||||
This ensures that the mapping is not active anymore outside the
|
||||
`with`-scope.
|
@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
|
||||
"kernels-community/activation" = ">=0.0.1"
|
||||
```
|
||||
|
||||
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
|
||||
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
|
||||
the locked revisions. The locked revision will be used when loading a kernel with
|
||||
`get_locked_kernel`:
|
||||
|
||||
@ -28,7 +28,7 @@ to `kernels` after doing an (editable or regular) installation of your project.
|
||||
|
||||
## Pre-downloading locked kernels
|
||||
|
||||
Locked kernels can be pre-downloaded by running `kernel download .` in your
|
||||
Locked kernels can be pre-downloaded by running `kernels download .` in your
|
||||
project directory. This will download the kernels to your local Hugging Face
|
||||
Hub cache.
|
||||
|
134
flake.lock
generated
Normal file
134
flake.lock
generated
Normal file
@ -0,0 +1,134 @@
|
||||
{
|
||||
"nodes": {
|
||||
"flake-compat": {
|
||||
"locked": {
|
||||
"lastModified": 1733328505,
|
||||
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_2": {
|
||||
"inputs": {
|
||||
"systems": "systems_2"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1737453259,
|
||||
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
|
||||
"owner": "danieldk",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"ref": "outlines-v0.1.4-tgi",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"tgi-nix": "tgi-nix"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_2": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"tgi-nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat",
|
||||
"flake-utils": "flake-utils_2",
|
||||
"nixpkgs": "nixpkgs"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1741617161,
|
||||
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
|
||||
"owner": "huggingface",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"ref": "kernels-0.2.0",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
54
flake.nix
Normal file
54
flake.nix
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
inputs = {
|
||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
|
||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
nixpkgs,
|
||||
flake-utils,
|
||||
tgi-nix,
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
inherit (tgi-nix.lib) config;
|
||||
overlays = [
|
||||
tgi-nix.overlays.default
|
||||
];
|
||||
};
|
||||
in
|
||||
{
|
||||
formatter = pkgs.nixfmt-rfc-style;
|
||||
devShells = with pkgs; rec {
|
||||
default = mkShell {
|
||||
buildInputs =
|
||||
[
|
||||
black
|
||||
mypy
|
||||
pyright
|
||||
ruff
|
||||
]
|
||||
++ (with python3.pkgs; [
|
||||
huggingface-hub
|
||||
pytest
|
||||
pytest-benchmark
|
||||
torch
|
||||
venvShellHook
|
||||
]);
|
||||
|
||||
venvDir = "./.venv";
|
||||
|
||||
postVenvCreation = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
( python -m pip install --no-build-isolation --no-dependencies -e . )
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "kernels"
|
||||
version = "0.2.1"
|
||||
version = "0.4.4"
|
||||
description = "Download compute kernels"
|
||||
authors = [
|
||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||
@ -8,13 +8,13 @@ authors = [
|
||||
{ name = "David Holtz", email = "david@huggingface.co" },
|
||||
{ name = "Nicolas Patry", email = "nicolas@huggingface.co" },
|
||||
]
|
||||
license = { text = "Apache-2.0" }
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.9"
|
||||
dependencies = [
|
||||
"huggingface-hub>=0.26.3",
|
||||
"packaging>=24.2",
|
||||
"tomli>=2.0.1; python_version<'3.11'",
|
||||
"torch>=2.5",
|
||||
"huggingface_hub>=0.26.0,<1.0",
|
||||
"packaging>=20.0",
|
||||
"tomli>=2.0; python_version<'3.11'",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
@ -27,6 +27,12 @@ dev = [
|
||||
"pytest >=8",
|
||||
# Whatever version is compatible with pytest.
|
||||
"pytest-benchmark",
|
||||
"torch >=2.5",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
docs = [
|
||||
"hf-doc-builder",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@ -35,7 +41,6 @@ kernels = "kernels.cli:main"
|
||||
[project.entry-points."egg_info.writers"]
|
||||
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
exclude = [
|
||||
".eggs",
|
||||
|
@ -1,3 +1,33 @@
|
||||
from kernels.utils import get_kernel, get_locked_kernel, install_kernel, load_kernel
|
||||
import importlib.metadata
|
||||
|
||||
__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
|
||||
from kernels.layer import (
|
||||
Device,
|
||||
LayerRepository,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
use_kernel_forward_from_hub,
|
||||
use_kernel_mapping,
|
||||
)
|
||||
from kernels.utils import (
|
||||
get_kernel,
|
||||
get_locked_kernel,
|
||||
has_kernel,
|
||||
install_kernel,
|
||||
load_kernel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_kernel",
|
||||
"get_locked_kernel",
|
||||
"has_kernel",
|
||||
"load_kernel",
|
||||
"install_kernel",
|
||||
"use_kernel_forward_from_hub",
|
||||
"use_kernel_mapping",
|
||||
"register_kernel_mapping",
|
||||
"replace_kernel_forward_from_hub",
|
||||
"LayerRepository",
|
||||
"Device",
|
||||
]
|
||||
|
||||
__version__ = importlib.metadata.version("kernels")
|
||||
|
308
src/kernels/layer.py
Normal file
308
src/kernels/layer.py
Normal file
@ -0,0 +1,308 @@
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
|
||||
from .utils import get_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import nn
|
||||
|
||||
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Device:
|
||||
type: str
|
||||
|
||||
# In the future we might add compute capabilities, etc.
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Device) and self.type == other.type
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerRepository:
|
||||
"""
|
||||
Repository and name of a layer.
|
||||
"""
|
||||
|
||||
layer_name: str = field(
|
||||
metadata={"help": "The name of the layer in the kernel repository."}
|
||||
)
|
||||
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
|
||||
revision: str = field(
|
||||
default="main", metadata={"help": "The revision of the layer."}
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, LayerRepository)
|
||||
and self.layer_name == other.layer_name
|
||||
and self.repo_id == other.repo_id
|
||||
and self.revision == other.revision
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.layer_name, self.repo_id, self.revision))
|
||||
|
||||
|
||||
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
|
||||
"_KERNEL_MAPPING", default={}
|
||||
)
|
||||
|
||||
|
||||
def use_kernel_mapping(
|
||||
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
|
||||
*,
|
||||
inherit_mapping: bool = True,
|
||||
):
|
||||
"""
|
||||
Context manager that sets a kernel mapping for the duration of the context.
|
||||
|
||||
Args:
|
||||
mapping (`Dict[str, Dict[Union[Device, str], LayerRepository]]`):
|
||||
A mapping between layer names and their corresponding kernel repositories.
|
||||
inherit_mapping (`bool`, *optional*, defaults to `True`):
|
||||
The current mapping will be extended by `mapping` when set to `True`.
|
||||
When set to `False`, the current mapping will be replaced by `mapping`
|
||||
for the duration of the context.
|
||||
|
||||
Returns:
|
||||
`ContextManager`: Context manager that sets up the mapping.
|
||||
"""
|
||||
|
||||
class ContextManager:
|
||||
def __enter__(self):
|
||||
# Mappings always stack on previous mappings.
|
||||
if inherit_mapping:
|
||||
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
|
||||
else:
|
||||
self.token = _KERNEL_MAPPING.set({})
|
||||
register_kernel_mapping(mapping)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
_KERNEL_MAPPING.reset(self.token)
|
||||
|
||||
return ContextManager()
|
||||
|
||||
|
||||
def register_kernel_mapping(
|
||||
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
|
||||
):
|
||||
"""
|
||||
Register a mapping between a layer name the corresponding kernel to use, depending on the device.
|
||||
This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
|
||||
|
||||
Args:
|
||||
mapping (`Dict[str, Dict[Union[Device, str], LayerRepository]]`):
|
||||
A mapping between layer names and their corresponding kernel repositories.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from kernels import LayerRepository, register_kernel_mapping
|
||||
|
||||
kernel_layer_mapping = {
|
||||
"LlamaRMSNorm": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="RmsNorm",
|
||||
revision="layers",
|
||||
),
|
||||
},
|
||||
}
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
```
|
||||
"""
|
||||
# Merge with existing mappings.
|
||||
for new_kernel, new_device_repos in mapping.items():
|
||||
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
|
||||
for new_device, new_repo in new_device_repos.items():
|
||||
if isinstance(new_device, str):
|
||||
device_repo[Device(type=new_device)] = new_repo
|
||||
else:
|
||||
device_repo[new_device] = new_repo
|
||||
|
||||
|
||||
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
|
||||
"""
|
||||
Replace the forward function of a layer using a layer from the kernel hub.
|
||||
This function monkeypatches a layer, replacing the `forward` method
|
||||
of the layer with that of a layer from the hub. The replacement is done
|
||||
when a layer matching `layer_name` and device type is registered through
|
||||
[`register_layer_mapping`]. The device type is inferred from the first
|
||||
argument to `forward`.
|
||||
|
||||
Args:
|
||||
cls (`nn.Module`):
|
||||
The layer class to replace the forward function of.
|
||||
layer_name (`str`):
|
||||
The name to assign to the layer.
|
||||
use_fallback (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the fallback forward function if no kernel mapping
|
||||
is found. If set to `False`, a `ValueError` will be raised if no kernel
|
||||
mapping is found.
|
||||
"""
|
||||
|
||||
fallback_forward = cls.forward
|
||||
|
||||
cached_layer: Dict[LayerRepository, nn.Module] = {}
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if _DISABLE_KERNEL_MAPPING:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
needs_backward = self.training
|
||||
kernel = _KERNEL_MAPPING.get().get(layer_name)
|
||||
if kernel is None:
|
||||
warnings.warn(
|
||||
"\n"
|
||||
f"No kernel mapping found for layer `{layer_name}`. "
|
||||
f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
|
||||
f"you want to use to the mapping. Defaulting to original forward implementation."
|
||||
)
|
||||
if not use_fallback:
|
||||
raise ValueError(f"No layer mapping for `{layer_name}`")
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
device = getattr(x, "device", None)
|
||||
if device is None:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
repo = kernel.get(Device(type=device.type))
|
||||
if repo is None:
|
||||
if not use_fallback:
|
||||
raise ValueError(
|
||||
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
|
||||
)
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
# Short-circuit if we already loaded the layer.
|
||||
layer = cached_layer.get(repo, None)
|
||||
if layer is not None:
|
||||
if needs_backward and not getattr(layer, "has_backward", True):
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
|
||||
layer = _get_kernel_layer(
|
||||
repo_id=repo.repo_id,
|
||||
layer_name=repo.layer_name,
|
||||
revision=repo.revision,
|
||||
)
|
||||
|
||||
# We have to validate against the original signature.
|
||||
orig_forward = cls.forward
|
||||
try:
|
||||
cls.forward = fallback_forward
|
||||
_validate_layer(check_cls=cls, cls=layer)
|
||||
finally:
|
||||
cls.forward = orig_forward
|
||||
|
||||
cached_layer[repo] = layer
|
||||
|
||||
if needs_backward and not getattr(layer, "has_backward", True):
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
|
||||
cls.forward = forward
|
||||
|
||||
|
||||
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
|
||||
"""
|
||||
Replace the forward function of a layer using a layer from the kernel hub.
|
||||
|
||||
This decorator can be applied to a layer and replaces the forward method
|
||||
of the layer with that of a layer from the hub. The replacement is done
|
||||
when a layer matching `layer_name` and device type is registered through
|
||||
[`register_layer_mapping`]. The device type is inferred from the first
|
||||
argument to `forward`.
|
||||
|
||||
Args:
|
||||
layer_name (`str`):
|
||||
The name to assign to the layer.
|
||||
use_fallback (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the fallback forward function if no kernel mapping
|
||||
is found. If set to `False`, a `ValueError` will be raised if no kernel
|
||||
mapping is found.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from kernels import use_kernel_forward_from_hub
|
||||
|
||||
@use_kernel_forward_from_hub(layer_name="LlamaRMSNorm")
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
# Original forward implementation
|
||||
pass
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
|
||||
"""Get a layer from a kernel."""
|
||||
|
||||
kernel = get_kernel(repo_id, revision=revision)
|
||||
|
||||
if getattr(kernel, "layers", None) is None:
|
||||
raise ValueError(
|
||||
f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
|
||||
)
|
||||
|
||||
layer = getattr(kernel.layers, layer_name, None)
|
||||
if layer is None:
|
||||
raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
|
||||
return layer
|
||||
|
||||
|
||||
def _validate_layer(*, check_cls, cls):
|
||||
# The layer must have at least have the following properties: (1) it
|
||||
# must be stateless; (2) the forward signature should correspond to
|
||||
# the signature it is replacing; (3) forward should not call other
|
||||
# methods.
|
||||
|
||||
from torch import nn
|
||||
|
||||
if not issubclass(cls, nn.Module):
|
||||
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
|
||||
|
||||
# We verify statelessness by checking that the does not have its own
|
||||
# constructor (since the constructor could add member variables)...
|
||||
if cls.__init__ is not nn.Module.__init__:
|
||||
raise TypeError("Layer must not override nn.Module constructor.")
|
||||
|
||||
# ... or predefined member variables.
|
||||
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
||||
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
||||
difference = cls_members - torch_module_members
|
||||
if difference != set() and difference != {"has_backward"}:
|
||||
raise TypeError("Layer must not contain additional members.")
|
||||
|
||||
# Check whether the forward signatures are similar.
|
||||
params = inspect.signature(cls.forward).parameters
|
||||
ref_params = inspect.signature(check_cls.forward).parameters
|
||||
|
||||
if len(params) != len(ref_params):
|
||||
raise TypeError(
|
||||
"Forward signature does not match: different number of arguments."
|
||||
)
|
||||
|
||||
for param, ref_param in zip(params.values(), ref_params.values()):
|
||||
if param.kind != ref_param.kind:
|
||||
raise TypeError(
|
||||
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
||||
)
|
@ -4,6 +4,7 @@ import importlib
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
@ -12,29 +13,45 @@ from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import file_exists, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
from kernels.lockfile import KernelLock, VariantLock
|
||||
|
||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
|
||||
def _get_cache_dir() -> Optional[str]:
|
||||
"""Returns the kernels cache directory."""
|
||||
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
if cache_dir is not None:
|
||||
logging.warning(
|
||||
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
|
||||
)
|
||||
return cache_dir
|
||||
|
||||
return os.environ.get("KERNELS_CACHE", None)
|
||||
|
||||
|
||||
CACHE_DIR: Optional[str] = _get_cache_dir()
|
||||
|
||||
|
||||
def build_variant() -> str:
|
||||
import torch
|
||||
|
||||
if torch.version.cuda is None:
|
||||
raise AssertionError(
|
||||
"This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled."
|
||||
)
|
||||
if torch.version.cuda is not None:
|
||||
cuda_version = parse(torch.version.cuda)
|
||||
compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
|
||||
elif torch.version.hip is not None:
|
||||
rocm_version = parse(torch.version.hip.split("-")[0])
|
||||
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||
else:
|
||||
raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
|
||||
|
||||
torch_version = parse(torch.__version__)
|
||||
cuda_version = parse(torch.version.cuda)
|
||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||
cpu = platform.machine()
|
||||
os = platform.system().lower()
|
||||
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
|
||||
|
||||
|
||||
def universal_build_variant() -> str:
|
||||
@ -140,16 +157,88 @@ def install_kernel_all_variants(
|
||||
|
||||
|
||||
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||
"""
|
||||
Load a kernel from the kernel hub.
|
||||
|
||||
This function downloads a kernel to the local Hugging Face Hub cache
|
||||
directory (if it was not downloaded before) and then loads the kernel.
|
||||
|
||||
Args:
|
||||
repo_id (`str`): The Hub repository containing the kernel.
|
||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||
revision (branch, tag, or commit) to download.
|
||||
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from kernels import get_kernel
|
||||
kernel = get_kernel("username/my-kernel")
|
||||
result = kernel.kernel_function(input_data)
|
||||
```
|
||||
"""
|
||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment
|
||||
|
||||
This function checks whether there exists a kernel build for the current
|
||||
environment (Torch version, compute framework and architecture).
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The kernel revision.
|
||||
|
||||
Returns:
|
||||
`bool`:
|
||||
`True` if a compatible kernel build exists for the current environment,
|
||||
`False` otherwise.
|
||||
"""
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
||||
if file_exists(
|
||||
repo_id,
|
||||
revision=revision,
|
||||
filename=f"build/{universal_variant}/{package_name}/__init__.py",
|
||||
):
|
||||
return True
|
||||
|
||||
return file_exists(
|
||||
repo_id,
|
||||
revision=revision,
|
||||
filename=f"build/{variant}/{package_name}/__init__.py",
|
||||
)
|
||||
|
||||
|
||||
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
||||
"""
|
||||
Get a pre-downloaded, locked kernel.
|
||||
Loads a pre-downloaded, locked kernel module from the local cache.
|
||||
|
||||
If `lockfile` is not specified, the lockfile will be loaded from the
|
||||
caller's package metadata.
|
||||
This function retrieves a kernel that was locked at a specific revision with
|
||||
`kernels lock <project>` and then downloaded with `kernels download <project>`.
|
||||
|
||||
This function will fail if the kernel was not locked or downloaded. If you want
|
||||
the kernel to be downloaded when it is not in the cache, use [`get_locked_kernel`]
|
||||
instead.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
lockfile (`Optional[Path]`, *optional*, defaults to `None`):
|
||||
Path to a lockfile containing the commit SHA for the kernel. If `None`,
|
||||
the lock information is automatically retrieved from the metadata of the
|
||||
calling package.
|
||||
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module corresponding to the locked version.
|
||||
"""
|
||||
if lockfile is None:
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
@ -194,7 +283,27 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
||||
|
||||
|
||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
|
||||
"""Get a kernel using a lock file."""
|
||||
"""
|
||||
Loads a locked kernel module.
|
||||
|
||||
This function retrieves a kernel that was locked at a specific revision with
|
||||
`kernels lock <project>`.
|
||||
|
||||
This function will download the locked kernel when it is not available in the
|
||||
cache. If you want loading to fail if the kernel is not in the cache, use
|
||||
[`load_kernel`] instead.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The Hub repository containing the kernel.
|
||||
lockfile (`Optional[Path]`, *optional*, defaults to `None`):
|
||||
Path to a lockfile containing the commit SHA for the kernel. If `None`,
|
||||
the lock information is automatically retrieved from the metadata of the
|
||||
calling package.
|
||||
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module corresponding to the locked version.
|
||||
"""
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
|
||||
if locked_sha is None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels import get_kernel
|
||||
from kernels import get_kernel, has_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -36,6 +36,22 @@ def test_gelu_fast(kernel, device):
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_exists",
|
||||
[
|
||||
("kernels-community/activation", "main", True),
|
||||
("kernels-community/triton-layer-norm", "main", True),
|
||||
# Repo only contains Torch 2.4 kernels (and we don't
|
||||
# support/test against this version).
|
||||
("kernels-test/only-torch-2.4", "main", False),
|
||||
("google-bert/bert-base-uncased", "87565a309", False),
|
||||
],
|
||||
)
|
||||
def test_has_kernel(kernel_exists):
|
||||
repo_id, revision, kernel = kernel_exists
|
||||
assert has_kernel(repo_id, revision=revision) == kernel
|
||||
|
||||
|
||||
def test_universal_kernel(universal_kernel):
|
||||
torch.manual_seed(0)
|
||||
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||
|
277
tests/test_layer.py
Normal file
277
tests/test_layer.py
Normal file
@ -0,0 +1,277 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from kernels import (
|
||||
Device,
|
||||
LayerRepository,
|
||||
register_kernel_mapping,
|
||||
use_kernel_forward_from_hub,
|
||||
)
|
||||
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
|
||||
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
},
|
||||
"SiluAndMulStringDevice": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Used to check that we called hub kernel.
|
||||
self.n_calls = 0
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
self.n_calls += 1
|
||||
d = input.shape[-1] // 2
|
||||
return F.silu(input[..., :d]) * input[..., d:]
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("SiluAndMul")
|
||||
class SiluAndMulWithKernel(SiluAndMul):
|
||||
pass
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("SiluAndMulStringDevice")
|
||||
class SiluAndMulStringDevice(SiluAndMul):
|
||||
pass
|
||||
|
||||
|
||||
def test_arg_kinds():
|
||||
@use_kernel_forward_from_hub("ArgKind")
|
||||
class ArgKind(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
arg1,
|
||||
arg2,
|
||||
*,
|
||||
kwarg1,
|
||||
kwarg2=42,
|
||||
):
|
||||
return (arg1, arg2, kwarg1, kwarg2)
|
||||
|
||||
arg_kind = ArgKind()
|
||||
assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
|
||||
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
def test_hub_forward(cls, device):
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
silu_and_mul = SiluAndMul()
|
||||
X = torch.randn((32, 64), device=device)
|
||||
Y = silu_and_mul(X)
|
||||
|
||||
silu_and_mul_with_kernel = cls()
|
||||
Y_kernel = silu_and_mul_with_kernel(X)
|
||||
|
||||
torch.testing.assert_close(Y_kernel, Y)
|
||||
|
||||
assert silu_and_mul.n_calls == 1
|
||||
if device == "cuda":
|
||||
assert silu_and_mul_with_kernel.n_calls == 0
|
||||
else:
|
||||
assert silu_and_mul_with_kernel.n_calls == 1
|
||||
|
||||
|
||||
def test_layer_fallback_works():
|
||||
@use_kernel_forward_from_hub("SiluAndMulNonExisting")
|
||||
class SiluAndMulWithKernelFallback(SiluAndMul):
|
||||
pass
|
||||
|
||||
# Check that we don't raise an exception for a non-existing kernel.
|
||||
SiluAndMulWithKernelFallback()
|
||||
|
||||
|
||||
def test_mapping_contexts():
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
|
||||
|
||||
extra_mapping1 = {
|
||||
"TestKernel": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
with use_kernel_mapping(extra_mapping1):
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
"SiluAndMulStringDevice",
|
||||
"TestKernel",
|
||||
}
|
||||
|
||||
extra_mapping2 = {
|
||||
"SiluAndMul": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-community/non-existing",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
with use_kernel_mapping(extra_mapping2):
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
"SiluAndMulStringDevice",
|
||||
"TestKernel",
|
||||
}
|
||||
assert (
|
||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||
== "kernels-community/non-existing"
|
||||
)
|
||||
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
"SiluAndMulStringDevice",
|
||||
"TestKernel",
|
||||
}
|
||||
assert (
|
||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||
== "kernels-community/activation"
|
||||
)
|
||||
|
||||
with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
}
|
||||
assert (
|
||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||
== "kernels-community/non-existing"
|
||||
)
|
||||
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
"SiluAndMulStringDevice",
|
||||
"TestKernel",
|
||||
}
|
||||
assert (
|
||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||
== "kernels-community/activation"
|
||||
)
|
||||
|
||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||
"SiluAndMul",
|
||||
"SiluAndMulStringDevice",
|
||||
}
|
||||
|
||||
|
||||
def test_validate_kernel_layer():
|
||||
class BadLayer(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.foo = 42
|
||||
|
||||
with pytest.raises(TypeError, match="not override"):
|
||||
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
|
||||
|
||||
class BadLayer2(nn.Module):
|
||||
foo: int = 42
|
||||
|
||||
with pytest.raises(TypeError, match="not contain additional members"):
|
||||
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
|
||||
|
||||
class BadLayer3(nn.Module):
|
||||
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
|
||||
|
||||
with pytest.raises(TypeError, match="different number of arguments"):
|
||||
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
|
||||
|
||||
class BadLayer4(nn.Module):
|
||||
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
with pytest.raises(TypeError, match="different kind of arguments"):
|
||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
||||
|
||||
|
||||
def test_fallback_used_when_training():
|
||||
@use_kernel_forward_from_hub("Linear")
|
||||
class TorchLinear(nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Used to check that we called hub kernel.
|
||||
self.n_calls = 0
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
self.n_calls += 1
|
||||
return super().forward(input)
|
||||
|
||||
linear = TorchLinear(32, 32).to("cuda")
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearImplicitBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 0
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Linear": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearNoBackward",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
linear.train()
|
||||
X = torch.randn(10, 32, device="cuda")
|
||||
linear(X)
|
||||
assert linear.n_calls == 1
|
||||
|
||||
linear.eval()
|
||||
linear(X)
|
||||
assert linear.n_calls == 1
|
Reference in New Issue
Block a user