mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-24 07:27:05 +08:00
Compare commits
1 Commits
v0.7.0
...
fix-comman
| Author | SHA1 | Date | |
|---|---|---|---|
| 03a8662f7f |
120
.github/workflows/publish.yml
vendored
120
.github/workflows/publish.yml
vendored
@ -1,120 +0,0 @@
|
|||||||
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
|
|
||||||
|
|
||||||
on: push
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
name: Build distribution 📦
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: "3.9"
|
|
||||||
- name: Install pypa/build
|
|
||||||
run: >-
|
|
||||||
python3 -m
|
|
||||||
pip install
|
|
||||||
build
|
|
||||||
--user
|
|
||||||
- name: Build a binary wheel and a source tarball
|
|
||||||
run: python3 -m build
|
|
||||||
- name: Store the distribution packages
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: python-package-distributions
|
|
||||||
path: dist/
|
|
||||||
|
|
||||||
publish-to-pypi:
|
|
||||||
name: >-
|
|
||||||
Publish Python 🐍 distribution 📦 to PyPI
|
|
||||||
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
|
|
||||||
needs:
|
|
||||||
- build
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
environment:
|
|
||||||
name: pypi
|
|
||||||
url: https://pypi.org/p/kernels
|
|
||||||
permissions:
|
|
||||||
id-token: write # IMPORTANT: mandatory for trusted publishing
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Download all the dists
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
name: python-package-distributions
|
|
||||||
path: dist/
|
|
||||||
- name: Publish distribution 📦 to PyPI
|
|
||||||
uses: pypa/gh-action-pypi-publish@release/v1
|
|
||||||
|
|
||||||
github-release:
|
|
||||||
name: >-
|
|
||||||
Sign the Python 🐍 distribution 📦 with Sigstore
|
|
||||||
and upload them to GitHub Release
|
|
||||||
needs:
|
|
||||||
- publish-to-pypi
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: write # IMPORTANT: mandatory for making GitHub Releases
|
|
||||||
id-token: write # IMPORTANT: mandatory for sigstore
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Download all the dists
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
name: python-package-distributions
|
|
||||||
path: dist/
|
|
||||||
- name: Sign the dists with Sigstore
|
|
||||||
uses: sigstore/gh-action-sigstore-python@v3.0.0
|
|
||||||
with:
|
|
||||||
inputs: >-
|
|
||||||
./dist/*.tar.gz
|
|
||||||
./dist/*.whl
|
|
||||||
- name: Create GitHub Release
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ github.token }}
|
|
||||||
run: >-
|
|
||||||
gh release create
|
|
||||||
"$GITHUB_REF_NAME"
|
|
||||||
--repo "$GITHUB_REPOSITORY"
|
|
||||||
--notes ""
|
|
||||||
- name: Upload artifact signatures to GitHub Release
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ github.token }}
|
|
||||||
# Upload to GitHub Release using the `gh` CLI.
|
|
||||||
# `dist/` contains the built packages, and the
|
|
||||||
# sigstore-produced signatures and certificates.
|
|
||||||
run: >-
|
|
||||||
gh release upload
|
|
||||||
"$GITHUB_REF_NAME" dist/**
|
|
||||||
--repo "$GITHUB_REPOSITORY"
|
|
||||||
|
|
||||||
publish-to-testpypi:
|
|
||||||
name: Publish Python 🐍 distribution 📦 to TestPyPI
|
|
||||||
needs:
|
|
||||||
- build
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
environment:
|
|
||||||
name: testpypi
|
|
||||||
url: https://test.pypi.org/p/kernels
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
id-token: write # IMPORTANT: mandatory for trusted publishing
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Download all the dists
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
name: python-package-distributions
|
|
||||||
path: dist/
|
|
||||||
- name: Publish distribution 📦 to TestPyPI
|
|
||||||
uses: pypa/gh-action-pypi-publish@release/v1
|
|
||||||
with:
|
|
||||||
repository-url: https://test.pypi.org/legacy/
|
|
||||||
skip-existing: true # Only upload when the version is unique.
|
|
||||||
14
.github/workflows/test.yml
vendored
14
.github/workflows/test.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
|||||||
max-parallel: 4
|
max-parallel: 4
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.10", "3.12"]
|
python-version: ["3.10", "3.12"]
|
||||||
torch-version: ["2.6.0", "2.7.0"]
|
torch-version: ["2.5.1", "2.6.0"]
|
||||||
|
|
||||||
env:
|
env:
|
||||||
UV_PYTHON_PREFERENCE: only-managed
|
UV_PYTHON_PREFERENCE: only-managed
|
||||||
@ -53,18 +53,6 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: uv run pytest tests
|
run: uv run pytest tests
|
||||||
|
|
||||||
- name: Check kernel conversion
|
|
||||||
run: |
|
|
||||||
uv pip install wheel
|
|
||||||
uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
|
|
||||||
uv pip install triton_layer_norm-0.0.1*.whl
|
|
||||||
uv run python -c "import triton_layer_norm"
|
|
||||||
|
|
||||||
- name: Check README generation
|
|
||||||
# For now, just checks that generation doesn't fail.
|
|
||||||
run: |
|
|
||||||
uv run kernels generate-readme kernels-community/triton-layer-norm
|
|
||||||
|
|
||||||
- name: Import check without torch
|
- name: Import check without torch
|
||||||
run: |
|
run: |
|
||||||
uv pip uninstall torch
|
uv pip uninstall torch
|
||||||
|
|||||||
@ -61,5 +61,4 @@ the Hub.
|
|||||||
- [Environment variables](docs/env.md)
|
- [Environment variables](docs/env.md)
|
||||||
- [Using kernels in a Docker container](docs/docker.md)
|
- [Using kernels in a Docker container](docs/docker.md)
|
||||||
- [Kernel requirements](docs/kernel-requirements.md)
|
- [Kernel requirements](docs/kernel-requirements.md)
|
||||||
- [Frequently Asked Questions](docs/faq.md)
|
|
||||||
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
|
|||||||
13
docs/faq.md
13
docs/faq.md
@ -1,13 +0,0 @@
|
|||||||
# FAQ
|
|
||||||
|
|
||||||
## Why is the kernelization step needed?
|
|
||||||
|
|
||||||
In earlier versions of `kernels`, a layer's `forward` was replaced by
|
|
||||||
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
|
|
||||||
new `forward` would dispatch to a kernel based on the device type,
|
|
||||||
whether a model was training, etc. However, this approach was
|
|
||||||
fundamentally incompatible with `torch.compile` since it relied
|
|
||||||
on data-dependent branching.
|
|
||||||
|
|
||||||
To avoid branching, we have to make dispatch decisions ahead of time,
|
|
||||||
which is what the `kernelize` function does.
|
|
||||||
@ -1,11 +1,8 @@
|
|||||||
# Kernel requirements
|
# Kernel requirements
|
||||||
|
|
||||||
Kernels on the Hub must fulfill the requirements outlined on this page. By
|
Kernels on the Hub must fulfill the requirements outlined on this page.
|
||||||
ensuring kernels are compliant, they can be used on a wide range of Linux
|
|
||||||
systems and Torch builds.
|
|
||||||
|
|
||||||
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
to build compliant kernels.
|
to build conforming kernels.
|
||||||
|
|
||||||
## Directory layout
|
## Directory layout
|
||||||
|
|
||||||
@ -13,21 +10,34 @@ A kernel repository on the Hub must contain a `build` directory. This
|
|||||||
directory contains build variants of a kernel in the form of directories
|
directory contains build variants of a kernel in the form of directories
|
||||||
following the template
|
following the template
|
||||||
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
|
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
|
||||||
For example `build/torch26-cxx98-cu118-x86_64-linux`.
|
For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
|
||||||
|
recommended build variants are:
|
||||||
|
|
||||||
Each variant directory must contain a single directory with the same name
|
- `torch25-cxx11-cu118-x86_64-linux`
|
||||||
|
- `torch25-cxx11-cu121-x86_64-linux`
|
||||||
|
- `torch25-cxx11-cu124-x86_64-linux`
|
||||||
|
- `torch25-cxx98-cu118-x86_64-linux`
|
||||||
|
- `torch25-cxx98-cu121-x86_64-linux`
|
||||||
|
- `torch25-cxx98-cu124-x86_64-linux`
|
||||||
|
- `torch26-cxx11-cu118-x86_64-linux`
|
||||||
|
- `torch26-cxx11-cu124-x86_64-linux`
|
||||||
|
- `torch26-cxx11-cu126-x86_64-linux`
|
||||||
|
- `torch26-cxx98-cu118-x86_64-linux`
|
||||||
|
- `torch26-cxx98-cu124-x86_64-linux`
|
||||||
|
- `torch26-cxx98-cu126-x86_64-linux`
|
||||||
|
|
||||||
|
This list will be updated as new PyTorch versions are released. Kernels
|
||||||
|
that are in pure Python (e.g. Triton kernels) only need to provide a
|
||||||
|
single build variant:
|
||||||
|
|
||||||
|
- `torch-universal`
|
||||||
|
|
||||||
|
Each variant directory should contain a single directory with the same name
|
||||||
as the repository (replacing `-` by `_`). For instance, kernels in the
|
as the repository (replacing `-` by `_`). For instance, kernels in the
|
||||||
`kernels-community/activation` repository have a directories like
|
`kernels-community/activation` repository have a directories like
|
||||||
`build/<variant>/activation`. This directory
|
`build/<variant>/activation`. This directory
|
||||||
must be a Python package with an `__init__.py` file.
|
must be a Python package with an `__init__.py` file.
|
||||||
|
|
||||||
## Build variants
|
|
||||||
|
|
||||||
A kernel can be compliant for a specific compute framework (e.g. CUDA) or
|
|
||||||
architecture (e.g. x86_64). For compliance with a compute framework and
|
|
||||||
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
|
|
||||||
must be available for that combination.
|
|
||||||
|
|
||||||
## Versioning
|
## Versioning
|
||||||
|
|
||||||
Kernels are versioned on the Hub using Git tags. Version tags must be of
|
Kernels are versioned on the Hub using Git tags. Version tags must be of
|
||||||
@ -37,14 +47,8 @@ to resolve the version constraints.
|
|||||||
## Native Python module
|
## Native Python module
|
||||||
|
|
||||||
Kernels will typically contain a native Python module with precompiled
|
Kernels will typically contain a native Python module with precompiled
|
||||||
compute kernels and bindings. This module must fulfill the requirements
|
compute kernels and bindings. This module must fulfill the following
|
||||||
outlined in this section. For all operating systems, a kernel must not
|
requirements:
|
||||||
have dynamic library dependencies outside:
|
|
||||||
|
|
||||||
- Torch;
|
|
||||||
- CUDA/ROCm libraries installed as dependencies of Torch.
|
|
||||||
|
|
||||||
### Linux
|
|
||||||
|
|
||||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||||
for compatibility with Python 3.9 and later.
|
for compatibility with Python 3.9 and later.
|
||||||
@ -56,18 +60,12 @@ have dynamic library dependencies outside:
|
|||||||
- CXXABI 1.3.11
|
- CXXABI 1.3.11
|
||||||
- GCC 7.0.0
|
- GCC 7.0.0
|
||||||
|
|
||||||
These requirement can be checked with the ABI checker (see below).
|
These requirement can be checked with the ABI checker (see below).
|
||||||
|
|
||||||
### macOS
|
- No dynamic library dependencies outside:
|
||||||
|
|
||||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
- Torch;
|
||||||
for compatibility with Python 3.9 and later.
|
- CUDA/ROCm libraries installed as dependencies of Torch.
|
||||||
- macOS deployment target 15.0.
|
|
||||||
- Metal 3.0 (`-std=metal3.0`).
|
|
||||||
|
|
||||||
The ABI3 requirement can be checked with the ABI checker (see below).
|
|
||||||
|
|
||||||
### ABI checker
|
|
||||||
|
|
||||||
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
|
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
|
||||||
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
|
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
|
||||||
@ -121,12 +119,9 @@ requirements:
|
|||||||
- The `forward` method has a signature that is compatible with the
|
- The `forward` method has a signature that is compatible with the
|
||||||
`forward` method that it is extending.
|
`forward` method that it is extending.
|
||||||
|
|
||||||
There are two exceptions to the _no class variables rule_:
|
The only exception to the _no class variables rule_ is addition of a
|
||||||
|
`has_backward` class variable. This variable is used to indicate whether
|
||||||
1. The `has_backward` variable can be used to indicate whether the layer has
|
the layer has a backward pass implemented (`True` when absent).
|
||||||
a backward pass implemented (`True` when absent).
|
|
||||||
2. The `can_torch_compile` variable can be used to indicate whether the layer
|
|
||||||
supports `torch.compile` (`False` when absent).
|
|
||||||
|
|
||||||
This is an example of a pure layer:
|
This is an example of a pure layer:
|
||||||
|
|
||||||
|
|||||||
142
docs/layers.md
142
docs/layers.md
@ -23,93 +23,33 @@ class SiluAndMul(nn.Module):
|
|||||||
return F.silu(input[..., :d]) * input[..., d:]
|
return F.silu(input[..., :d]) * input[..., d:]
|
||||||
```
|
```
|
||||||
|
|
||||||
The decorator does not change the behavior of the class -- it annotates
|
The decorator changes the layer, so that other implementations of the `forward`
|
||||||
the class with the given name (here `SiluAndMul`). The `kernelize` function
|
method can be registered using the name `SiluAndMul`.
|
||||||
described below uses this name to look up kernels for the layer.
|
|
||||||
|
|
||||||
### External layers
|
### External layers
|
||||||
|
|
||||||
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
|
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
|
||||||
decorator can be made extensible using the `replace_kernel_forward_from_hub`
|
decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
|
||||||
function:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from somelibrary import SiluAndMul
|
from somelibrary import SiluAndMul
|
||||||
|
|
||||||
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
|
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
|
||||||
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
|
||||||
|
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
|
||||||
|
section for more information.
|
||||||
|
|
||||||
**Warning:** we strongly recommend using layers with a decorator, since
|
**Warning:** we strongly recommend using layers with a decorator, since
|
||||||
it signifies that the maintainer intends to keep the `forward` signature
|
it signifies that the maintainer intends to keep the `forward` signature
|
||||||
compatible with layers from the hub.
|
compatible with layers from the hub.
|
||||||
|
|
||||||
## Kernelizing a model
|
|
||||||
|
|
||||||
A model will not use Hub kernels by default, even if it contains extensible
|
|
||||||
layers. To enable the use of Hub kernels in the model, it needs to be
|
|
||||||
'kernelized' using the `kernelize` function. This function traverses the
|
|
||||||
model graph and replaces the `forward` methods of extensible layers for which
|
|
||||||
Hub kernels are registered. Kernelize can be used as follows:
|
|
||||||
|
|
||||||
```python
|
|
||||||
model = MyModel(...)
|
|
||||||
model = kernelize(model, mode=Mode.INFERENCE)
|
|
||||||
```
|
|
||||||
|
|
||||||
The `mode` specifies that the model will be used in inference. Similarly,
|
|
||||||
you can ask `kernelize` to prepare the model for training:
|
|
||||||
|
|
||||||
```python
|
|
||||||
model = MyModel(...)
|
|
||||||
model = kernelize(model, mode=Mode.TRAINING)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note:** the `kernelize` function modifies the model in-place, the model
|
|
||||||
itself is returned as a convenience.
|
|
||||||
|
|
||||||
### Kernel device
|
|
||||||
|
|
||||||
Kernels can be registered per device type. For instance, separate `cuda` and
|
|
||||||
`metal` kernels could be registered for the name `SiluAndMul`. By default,
|
|
||||||
`kernelize` will try to infer the device type from the model's parameters.
|
|
||||||
You can pass the device type to `kernelize` if the device type cannot be
|
|
||||||
inferred (e.g. because the model has no parameters):
|
|
||||||
|
|
||||||
```python
|
|
||||||
model = MyModel(...)
|
|
||||||
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
|
|
||||||
```
|
|
||||||
|
|
||||||
### `torch.compile`
|
|
||||||
|
|
||||||
Not all Hub kernels support `torch.compile`. If you want to compile a model
|
|
||||||
after kernelizing it, you need to add this to the mode. You can use the
|
|
||||||
set union (`|`) operator to add `TORCH_COMPILE` to the mode:
|
|
||||||
|
|
||||||
```python
|
|
||||||
model = MyModel(...)
|
|
||||||
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Fallback `forward`
|
|
||||||
|
|
||||||
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
|
|
||||||
kernel does not support backward passes or `torch.compile` respectively,
|
|
||||||
`kernenize` will fall back to the original, non-kernelized, layer. You
|
|
||||||
can let `kernelize` raise an exception instead by using `use_fallback=False`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
model = MyModel(...)
|
|
||||||
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
This can be useful if you want to guarantee that Hub kernels are used.
|
|
||||||
|
|
||||||
## Registering a hub kernel for a layer
|
## Registering a hub kernel for a layer
|
||||||
|
|
||||||
`kernelize` relies on kernel mappings to find Hub kernels for layers.
|
Once a layer is made extensible, users can register hub kernels for it
|
||||||
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
|
by name using the `register_kernel_mapping` function. For example:
|
||||||
the Hub. For example:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
kernel_layer_mapping = {
|
kernel_layer_mapping = {
|
||||||
@ -117,14 +57,11 @@ kernel_layer_mapping = {
|
|||||||
"cuda": LayerRepository(
|
"cuda": LayerRepository(
|
||||||
repo_id="kernels-community/activation",
|
repo_id="kernels-community/activation",
|
||||||
layer_name="SiluAndMul",
|
layer_name="SiluAndMul",
|
||||||
|
revision="layers",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
|
||||||
|
|
||||||
You can register such a mapping using `register_kernel_mapping`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
register_kernel_mapping(kernel_layer_mapping)
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -135,63 +72,8 @@ used with the `use_kernel_mapping` context manager:
|
|||||||
```python
|
```python
|
||||||
with use_kernel_mapping(kernel_layer_mapping):
|
with use_kernel_mapping(kernel_layer_mapping):
|
||||||
# Use the layer for which the mapping is applied.
|
# Use the layer for which the mapping is applied.
|
||||||
model = kernelize(model)
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
This ensures that the mapping is not active anymore outside the
|
This ensures that the mapping is not active anymore outside the
|
||||||
`with`-scope.
|
`with`-scope.
|
||||||
|
|
||||||
### Registering kernels for specific modes
|
|
||||||
|
|
||||||
You might want to register two different kernels for a particular layer,
|
|
||||||
where one kernel is optimized for a specific mode. You can do so by
|
|
||||||
registering layer repositories for specific modes. For example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
kernel_layer_mapping = {
|
|
||||||
"SiluAndMul": {
|
|
||||||
"cuda": {
|
|
||||||
Mode.INFERENCE: LayerRepository(
|
|
||||||
repo_id="kernels-community/activation-inference-optimized",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
),
|
|
||||||
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
|
|
||||||
repo_id="kernels-community/activation-training-optimized",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
The kernels will match exactly on the mode. So, for instance in the example above, no kernel
|
|
||||||
layer is used when the `mode` passed to `kernelize` is
|
|
||||||
`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. However, if you want to
|
|
||||||
register a kernel to be used when the mode does not match any of the
|
|
||||||
modes in the mapping, you can use the special `Mode.DEFAULT` mode to do
|
|
||||||
so. For example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
kernel_layer_mapping = {
|
|
||||||
"SiluAndMul": {
|
|
||||||
"cuda": {
|
|
||||||
Mode.DEFAULT: LayerRepository(
|
|
||||||
repo_id="kernels-community/activation",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
),
|
|
||||||
Mode.INFERENCE: LayerRepository(
|
|
||||||
repo_id="kernels-community/activation-inference-optimized",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
),
|
|
||||||
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
|
|
||||||
repo_id="kernels-community/activation-training-optimized",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
In this case, modes other than `Mode.INFERENCE` and
|
|
||||||
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
|
|
||||||
`kernels-community/activation`.
|
|
||||||
|
|||||||
55
flake.lock
generated
55
flake.lock
generated
@ -51,38 +51,18 @@
|
|||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"hf-nix": {
|
|
||||||
"inputs": {
|
|
||||||
"flake-compat": "flake-compat",
|
|
||||||
"flake-utils": "flake-utils_2",
|
|
||||||
"nixpkgs": "nixpkgs"
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1750775451,
|
|
||||||
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
|
|
||||||
"owner": "huggingface",
|
|
||||||
"repo": "hf-nix",
|
|
||||||
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "huggingface",
|
|
||||||
"repo": "hf-nix",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1747820358,
|
"lastModified": 1737453259,
|
||||||
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
|
||||||
"owner": "danieldk",
|
"owner": "danieldk",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "danieldk",
|
"owner": "danieldk",
|
||||||
"ref": "cudatoolkit-12.9-kernel-builder",
|
"ref": "outlines-v0.1.4-tgi",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
@ -90,11 +70,11 @@
|
|||||||
"root": {
|
"root": {
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"flake-utils": "flake-utils",
|
"flake-utils": "flake-utils",
|
||||||
"hf-nix": "hf-nix",
|
|
||||||
"nixpkgs": [
|
"nixpkgs": [
|
||||||
"hf-nix",
|
"tgi-nix",
|
||||||
"nixpkgs"
|
"nixpkgs"
|
||||||
]
|
],
|
||||||
|
"tgi-nix": "tgi-nix"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"systems": {
|
"systems": {
|
||||||
@ -126,6 +106,27 @@
|
|||||||
"repo": "default",
|
"repo": "default",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"tgi-nix": {
|
||||||
|
"inputs": {
|
||||||
|
"flake-compat": "flake-compat",
|
||||||
|
"flake-utils": "flake-utils_2",
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1741617161,
|
||||||
|
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
|
||||||
|
"owner": "huggingface",
|
||||||
|
"repo": "text-generation-inference-nix",
|
||||||
|
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "huggingface",
|
||||||
|
"ref": "kernels-0.2.0",
|
||||||
|
"repo": "text-generation-inference-nix",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"root": "root",
|
"root": "root",
|
||||||
|
|||||||
15
flake.nix
15
flake.nix
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
inputs = {
|
inputs = {
|
||||||
hf-nix.url = "github:huggingface/hf-nix";
|
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
|
||||||
nixpkgs.follows = "hf-nix/nixpkgs";
|
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
};
|
};
|
||||||
outputs =
|
outputs =
|
||||||
@ -9,21 +9,21 @@
|
|||||||
self,
|
self,
|
||||||
nixpkgs,
|
nixpkgs,
|
||||||
flake-utils,
|
flake-utils,
|
||||||
hf-nix,
|
tgi-nix,
|
||||||
}:
|
}:
|
||||||
flake-utils.lib.eachDefaultSystem (
|
flake-utils.lib.eachDefaultSystem (
|
||||||
system:
|
system:
|
||||||
let
|
let
|
||||||
pkgs = import nixpkgs {
|
pkgs = import nixpkgs {
|
||||||
inherit system;
|
inherit system;
|
||||||
config = hf-nix.lib.config system;
|
inherit (tgi-nix.lib) config;
|
||||||
overlays = [
|
overlays = [
|
||||||
hf-nix.overlays.default
|
tgi-nix.overlays.default
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
formatter = pkgs.nixfmt-tree;
|
formatter = pkgs.nixfmt-rfc-style;
|
||||||
devShells = with pkgs; rec {
|
devShells = with pkgs; rec {
|
||||||
default = mkShell {
|
default = mkShell {
|
||||||
buildInputs =
|
buildInputs =
|
||||||
@ -34,13 +34,10 @@
|
|||||||
ruff
|
ruff
|
||||||
]
|
]
|
||||||
++ (with python3.pkgs; [
|
++ (with python3.pkgs; [
|
||||||
docutils
|
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
pytest
|
pytest
|
||||||
pytest-benchmark
|
pytest-benchmark
|
||||||
pyyaml
|
|
||||||
torch
|
torch
|
||||||
types-pyyaml
|
|
||||||
venvShellHook
|
venvShellHook
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "kernels"
|
name = "kernels"
|
||||||
version = "0.7.0"
|
version = "0.4.4"
|
||||||
description = "Download compute kernels"
|
description = "Download compute kernels"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||||
@ -14,7 +14,6 @@ requires-python = ">= 3.9"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"huggingface_hub>=0.26.0,<1.0",
|
"huggingface_hub>=0.26.0,<1.0",
|
||||||
"packaging>=20.0",
|
"packaging>=20.0",
|
||||||
"pyyaml>=6",
|
|
||||||
"tomli>=2.0; python_version<'3.11'",
|
"tomli>=2.0; python_version<'3.11'",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -24,12 +23,11 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"mypy >= 1.15.0",
|
"mypy == 1.14.1",
|
||||||
"pytest >=8",
|
"pytest >=8",
|
||||||
# Whatever version is compatible with pytest.
|
# Whatever version is compatible with pytest.
|
||||||
"pytest-benchmark",
|
"pytest-benchmark",
|
||||||
"torch >=2.5",
|
"torch >=2.5",
|
||||||
"types-pyyaml"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@ -1,4 +0,0 @@
|
|||||||
[pytest]
|
|
||||||
markers =
|
|
||||||
darwin_only: marks tests that should only run on macOS
|
|
||||||
linux_only: marks tests that should only run on Linux
|
|
||||||
@ -1,8 +1,6 @@
|
|||||||
from kernels.layer import (
|
from kernels.layer import (
|
||||||
Device,
|
Device,
|
||||||
LayerRepository,
|
LayerRepository,
|
||||||
Mode,
|
|
||||||
kernelize,
|
|
||||||
register_kernel_mapping,
|
register_kernel_mapping,
|
||||||
replace_kernel_forward_from_hub,
|
replace_kernel_forward_from_hub,
|
||||||
use_kernel_forward_from_hub,
|
use_kernel_forward_from_hub,
|
||||||
@ -10,7 +8,6 @@ from kernels.layer import (
|
|||||||
)
|
)
|
||||||
from kernels.utils import (
|
from kernels.utils import (
|
||||||
get_kernel,
|
get_kernel,
|
||||||
get_local_kernel,
|
|
||||||
get_locked_kernel,
|
get_locked_kernel,
|
||||||
has_kernel,
|
has_kernel,
|
||||||
install_kernel,
|
install_kernel,
|
||||||
@ -18,18 +15,15 @@ from kernels.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Device",
|
|
||||||
"LayerRepository",
|
|
||||||
"Mode",
|
|
||||||
"get_kernel",
|
"get_kernel",
|
||||||
"get_local_kernel",
|
|
||||||
"get_locked_kernel",
|
"get_locked_kernel",
|
||||||
"has_kernel",
|
"has_kernel",
|
||||||
"install_kernel",
|
|
||||||
"kernelize",
|
|
||||||
"load_kernel",
|
"load_kernel",
|
||||||
"register_kernel_mapping",
|
"install_kernel",
|
||||||
"replace_kernel_forward_from_hub",
|
|
||||||
"use_kernel_forward_from_hub",
|
"use_kernel_forward_from_hub",
|
||||||
"use_kernel_mapping",
|
"use_kernel_mapping",
|
||||||
|
"register_kernel_mapping",
|
||||||
|
"replace_kernel_forward_from_hub",
|
||||||
|
"LayerRepository",
|
||||||
|
"Device",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,751 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
# Re pattern to catch things inside ` ` in :obj:`thing`.
|
|
||||||
_re_obj = re.compile(r":obj:`([^`]+)`")
|
|
||||||
# Re pattern to catch things inside ` ` in :math:`thing`.
|
|
||||||
_re_math = re.compile(r":math:`([^`]+)`")
|
|
||||||
# Re pattern to catch things between single backquotes.
|
|
||||||
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
|
|
||||||
# Re pattern to catch things between double backquotes.
|
|
||||||
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
|
|
||||||
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
|
|
||||||
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_rst_formatting(text):
|
|
||||||
"""
|
|
||||||
Convert rst syntax for formatting to markdown in a given text.
|
|
||||||
"""
|
|
||||||
# Remove :class:, :func: and :meth: markers. To code-links and put double backquotes
|
|
||||||
# (to not be caught by the italic conversion).
|
|
||||||
text = _re_func_class.sub(r"[``\1``]", text)
|
|
||||||
# Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes
|
|
||||||
# (to not be caught by the italic conversion).
|
|
||||||
text = _re_obj.sub(r"``\1``", text)
|
|
||||||
# Remove :math: markers.
|
|
||||||
text = _re_math.sub(r"\\\\(\1\\\\)", text)
|
|
||||||
# Convert content in single backquotes to italic.
|
|
||||||
text = _re_single_backquotes.sub(r"\1*\2*\3", text)
|
|
||||||
# Convert content in double backquotes to single backquotes.
|
|
||||||
text = _re_double_backquotes.sub(r"\1`\2`\3", text)
|
|
||||||
# Remove remaining ::
|
|
||||||
text = re.sub(r"::\n", "", text)
|
|
||||||
|
|
||||||
# Remove new lines inside blocks in backsticks as they will be kept.
|
|
||||||
lines = text.split("\n")
|
|
||||||
in_code = False
|
|
||||||
text = None
|
|
||||||
for line in lines:
|
|
||||||
if in_code:
|
|
||||||
splits = line.split("`")
|
|
||||||
in_code = len(splits) > 1 and len(splits) % 2 == 1
|
|
||||||
if len(splits) == 1:
|
|
||||||
# Some forgotten lone backstick
|
|
||||||
text += "\n" + line
|
|
||||||
else:
|
|
||||||
text += " " + line.lstrip()
|
|
||||||
else:
|
|
||||||
if text is not None:
|
|
||||||
text += "\n" + line
|
|
||||||
else:
|
|
||||||
text = line
|
|
||||||
splits = line.split("`")
|
|
||||||
in_code = len(splits) % 2 == 0
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
# Re pattern to catch description and url in links of the form `description <url>`_.
|
|
||||||
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
|
|
||||||
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
|
|
||||||
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
|
|
||||||
# Re pattern to catch reference in links of the form :doc:`reference`.
|
|
||||||
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
|
|
||||||
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
|
|
||||||
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
|
|
||||||
# Re pattern to catch reference in links of the form :ref:`reference`.
|
|
||||||
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
|
|
||||||
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
|
|
||||||
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_rst_links(text, page_info):
|
|
||||||
"""
|
|
||||||
Convert the rst links in text to markdown.
|
|
||||||
"""
|
|
||||||
if "package_name" not in page_info:
|
|
||||||
raise ValueError("`page_info` must contain at least the package_name.")
|
|
||||||
package_name = page_info["package_name"]
|
|
||||||
version = page_info.get("version", "main")
|
|
||||||
language = page_info.get("language", "en")
|
|
||||||
no_prefix = page_info.get("no_prefix", False)
|
|
||||||
|
|
||||||
prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
|
|
||||||
# Links of the form :doc:`page`
|
|
||||||
text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
|
|
||||||
# Links of the form :doc:`text <page>`
|
|
||||||
text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
|
|
||||||
|
|
||||||
if "page" in page_info and not no_prefix:
|
|
||||||
page = str(page_info["page"])
|
|
||||||
if page.endswith(".html"):
|
|
||||||
page = page[:-5]
|
|
||||||
prefix = f"{prefix}{page}"
|
|
||||||
else:
|
|
||||||
prefix = ""
|
|
||||||
# Refs of the form :ref:`page`
|
|
||||||
text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
|
|
||||||
# Refs of the form :ref:`text <page>`
|
|
||||||
text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
|
|
||||||
|
|
||||||
# Links with a prefix
|
|
||||||
# TODO: when it exists, use the API to deal with prefix links properly.
|
|
||||||
prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
|
|
||||||
text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
|
|
||||||
# Other links
|
|
||||||
text = _re_links.sub(r"[\1](\2)", text)
|
|
||||||
# Relative links or Transformers links need to remove the .html
|
|
||||||
if (
|
|
||||||
"(https://https://huggingface.co/" in text
|
|
||||||
or re.search(r"\(\.+/", text) is not None
|
|
||||||
):
|
|
||||||
text = text.replace(".html", "")
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
# Re pattern that catches examples blocks of the form `Example::`.
|
|
||||||
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
|
|
||||||
# Re pattern that catches rst blocks of the form `.. block_name::`.
|
|
||||||
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
|
|
||||||
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
|
|
||||||
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
|
|
||||||
|
|
||||||
|
|
||||||
def is_empty_line(line):
|
|
||||||
return len(line) == 0 or line.isspace()
|
|
||||||
|
|
||||||
|
|
||||||
def find_indent(line):
|
|
||||||
"""
|
|
||||||
Returns the number of spaces that start a line indent.
|
|
||||||
"""
|
|
||||||
search = re.search(r"^(\s*)(?:\S|$)", line)
|
|
||||||
if search is None:
|
|
||||||
return 0
|
|
||||||
return len(search.groups()[0])
|
|
||||||
|
|
||||||
|
|
||||||
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_special_chars(text):
|
|
||||||
"""
|
|
||||||
Converts { and < that have special meanings in MDX.
|
|
||||||
"""
|
|
||||||
text = text.replace("{", "&lcub;")
|
|
||||||
# We don't want to replace those by the HTML code, so we temporarily set them at LTHTML
|
|
||||||
text = re.sub(
|
|
||||||
r"<(img|br|hr|Youtube)", r"LTHTML\1", text
|
|
||||||
) # html void elements with no closing counterpart
|
|
||||||
_re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
|
|
||||||
while _re_lt_html.search(text):
|
|
||||||
text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
|
|
||||||
text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&lt;\2", text)
|
|
||||||
text = text.replace("LTHTML", "<")
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def parse_options(block_content):
|
|
||||||
"""
|
|
||||||
Parses the option in some rst block content.
|
|
||||||
"""
|
|
||||||
block_lines = block_content.split("\n")
|
|
||||||
block_indent = find_indent(block_lines[0])
|
|
||||||
current_option = None
|
|
||||||
result = {}
|
|
||||||
for line in block_lines:
|
|
||||||
if _re_rst_option.search(line) is not None:
|
|
||||||
current_option, value = _re_rst_option.search(line).groups()
|
|
||||||
result[current_option] = value.lstrip()
|
|
||||||
elif find_indent(line) > block_indent:
|
|
||||||
result[current_option] += " " + line.lstrip()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def apply_min_indent(text, min_indent):
|
|
||||||
"""
|
|
||||||
Make sure all lines in a text are have a minimum indentation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (`str`): The text to treat.
|
|
||||||
min_indent (`int`): The minimal indentation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`str`: The processed text.
|
|
||||||
"""
|
|
||||||
lines = text.split("\n")
|
|
||||||
idx = 0
|
|
||||||
while idx < len(lines):
|
|
||||||
if is_empty_line(lines[idx]):
|
|
||||||
idx += 1
|
|
||||||
continue
|
|
||||||
indent = find_indent(lines[idx])
|
|
||||||
if indent < min_indent:
|
|
||||||
while idx < len(lines) and (
|
|
||||||
find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
|
|
||||||
):
|
|
||||||
if not is_empty_line(lines[idx]):
|
|
||||||
lines[idx] = " " * (min_indent - indent) + lines[idx]
|
|
||||||
idx += 1
|
|
||||||
else:
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_rst_blocks(text, page_info):
|
|
||||||
"""
|
|
||||||
Converts rst special blocks (examples, notes) into MDX.
|
|
||||||
"""
|
|
||||||
if "package_name" not in page_info:
|
|
||||||
raise ValueError("`page_info` must contain at least the package_name.")
|
|
||||||
package_name = page_info["package_name"]
|
|
||||||
version = page_info.get("version", "main")
|
|
||||||
language = page_info.get("language", "en")
|
|
||||||
|
|
||||||
lines = text.split("\n")
|
|
||||||
idx = 0
|
|
||||||
new_lines = []
|
|
||||||
while idx < len(lines):
|
|
||||||
block_type = None
|
|
||||||
block_info = None
|
|
||||||
if _re_block.search(lines[idx]) is not None:
|
|
||||||
block_type = _re_block.search(lines[idx]).groups()[0]
|
|
||||||
if _re_block_info.search(lines[idx]) is not None:
|
|
||||||
block_info = _re_block_info.search(lines[idx]).groups()[0]
|
|
||||||
elif _re_example.search(lines[idx]) is not None:
|
|
||||||
block_type = "code-block-example"
|
|
||||||
block_info = "python"
|
|
||||||
example_name = _re_example.search(lines[idx]).groups()[0]
|
|
||||||
new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
|
|
||||||
elif lines[idx].strip() == "..":
|
|
||||||
block_type = "comment"
|
|
||||||
elif lines[idx].strip() == "::":
|
|
||||||
block_type = "code-block"
|
|
||||||
|
|
||||||
if block_type is not None:
|
|
||||||
block_indent = find_indent(lines[idx])
|
|
||||||
# Find the next nonempty line
|
|
||||||
idx += 1
|
|
||||||
while idx < len(lines) and is_empty_line(lines[idx]):
|
|
||||||
idx += 1
|
|
||||||
# Grab the indent of the return line, this block will stop when we unindent under it (or has already)
|
|
||||||
example_indent = (
|
|
||||||
find_indent(lines[idx]) if idx < len(lines) else block_indent
|
|
||||||
)
|
|
||||||
|
|
||||||
if example_indent == block_indent:
|
|
||||||
block_content = ""
|
|
||||||
else:
|
|
||||||
block_lines = []
|
|
||||||
while idx < len(lines) and (
|
|
||||||
is_empty_line(lines[idx])
|
|
||||||
or find_indent(lines[idx]) >= example_indent
|
|
||||||
):
|
|
||||||
block_lines.append(lines[idx][example_indent:])
|
|
||||||
idx += 1
|
|
||||||
block_content = "\n".join(block_lines)
|
|
||||||
|
|
||||||
if block_type in ["code", "code-block"]:
|
|
||||||
prefix = "```" if block_info is None else f"```{block_info}"
|
|
||||||
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
|
|
||||||
elif block_type == "code-block-example":
|
|
||||||
prefix = f"<example>```{block_info}"
|
|
||||||
new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
|
|
||||||
elif block_type == "note":
|
|
||||||
new_lines.append(
|
|
||||||
apply_min_indent(
|
|
||||||
f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif block_type == "warning":
|
|
||||||
new_lines.append(
|
|
||||||
apply_min_indent(
|
|
||||||
"<Tip warning={true}>\n\n"
|
|
||||||
+ f"{block_content.strip()}\n\n</Tip>\n",
|
|
||||||
block_indent,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif block_type == "raw":
|
|
||||||
new_lines.append(block_content.strip() + "\n")
|
|
||||||
elif block_type == "math":
|
|
||||||
new_lines.append(f"$${block_content.strip()}$$\n")
|
|
||||||
elif block_type == "comment":
|
|
||||||
new_lines.append(f"<!--{block_content.strip()}\n-->\n")
|
|
||||||
elif block_type == "autofunction":
|
|
||||||
if block_info is not None:
|
|
||||||
new_lines.append(f"[[autodoc]] {block_info}\n")
|
|
||||||
elif block_type == "autoclass":
|
|
||||||
if block_info is not None:
|
|
||||||
block = f"[[autodoc]] {block_info}\n"
|
|
||||||
options = parse_options(block_content)
|
|
||||||
if "special-members" in options:
|
|
||||||
special_members = options["special-members"].split(", ")
|
|
||||||
for special_member in special_members:
|
|
||||||
block += f" - {special_member}\n"
|
|
||||||
if "members" in options:
|
|
||||||
members = options["members"]
|
|
||||||
if len(members) == 0:
|
|
||||||
block += " - all\n"
|
|
||||||
else:
|
|
||||||
for member in members.split(", "):
|
|
||||||
block += f" - {member}\n"
|
|
||||||
new_lines.append(block)
|
|
||||||
elif block_type == "image":
|
|
||||||
options = parse_options(block_content)
|
|
||||||
target = options.pop("target", None)
|
|
||||||
if block_info is not None:
|
|
||||||
options["src"] = block_info
|
|
||||||
else:
|
|
||||||
if target is None:
|
|
||||||
raise ValueError("Image source not defined.")
|
|
||||||
options["src"] = target
|
|
||||||
# Adapt path
|
|
||||||
options["src"] = options["src"].replace(
|
|
||||||
"/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
|
|
||||||
)
|
|
||||||
html_code = " ".join(
|
|
||||||
[f'{key}="{value}"' for key, value in options.items()]
|
|
||||||
)
|
|
||||||
new_lines.append(f"<img {html_code}/>\n")
|
|
||||||
|
|
||||||
else:
|
|
||||||
new_lines.append(
|
|
||||||
f"{block_type},{block_info}\n{block_content.rstrip()}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
new_lines.append(lines[idx])
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
return "\n".join(new_lines)
|
|
||||||
|
|
||||||
|
|
||||||
# Re pattern that catches rst args blocks of the form `Parameters:`.
|
|
||||||
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
|
|
||||||
# Re pattern that catches return blocks of the form `Return:`.
|
|
||||||
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
|
|
||||||
|
|
||||||
|
|
||||||
def split_return_line(line):
|
|
||||||
"""
|
|
||||||
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
|
|
||||||
"""
|
|
||||||
splits_on_colon = line.split(":")
|
|
||||||
idx = 1
|
|
||||||
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
|
|
||||||
idx += 2
|
|
||||||
if idx >= len(splits_on_colon):
|
|
||||||
if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
|
|
||||||
return line, ""
|
|
||||||
return None, line
|
|
||||||
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
|
|
||||||
|
|
||||||
|
|
||||||
def split_raise_line(line):
|
|
||||||
"""
|
|
||||||
Split the raise line with format `SomeError some doc`.
|
|
||||||
"""
|
|
||||||
splits_on_colon = line.strip().split(" ")
|
|
||||||
error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:])
|
|
||||||
if error_type and error_type[-1] == ":":
|
|
||||||
error_type = error_type[:-1]
|
|
||||||
return error_type, doc
|
|
||||||
|
|
||||||
|
|
||||||
def split_arg_line(line):
|
|
||||||
"""
|
|
||||||
Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
|
|
||||||
"""
|
|
||||||
splits_on_colon = line.split(":")
|
|
||||||
idx = 1
|
|
||||||
while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
|
|
||||||
idx += 2
|
|
||||||
if idx >= len(splits_on_colon):
|
|
||||||
return line, ""
|
|
||||||
return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidRstDocstringError(ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
_re_parameters = re.compile(
|
|
||||||
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
|
|
||||||
)
|
|
||||||
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_rst_docstring(docstring):
|
|
||||||
"""
|
|
||||||
Parses a docstring written in rst, in particular the list of arguments and the return type.
|
|
||||||
"""
|
|
||||||
lines = docstring.split("\n")
|
|
||||||
idx = 0
|
|
||||||
while idx < len(lines):
|
|
||||||
# Parameters section
|
|
||||||
if _re_args.search(lines[idx]) is not None:
|
|
||||||
# Title of the section.
|
|
||||||
lines[idx] = "<parameters>\n"
|
|
||||||
# Find the next nonempty line
|
|
||||||
idx += 1
|
|
||||||
while is_empty_line(lines[idx]):
|
|
||||||
idx += 1
|
|
||||||
# Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
|
|
||||||
# Returns or Raises block.
|
|
||||||
param_indent = find_indent(lines[idx])
|
|
||||||
while (
|
|
||||||
idx < len(lines)
|
|
||||||
and find_indent(lines[idx]) == param_indent
|
|
||||||
and _re_returns.search(lines[idx]) is None
|
|
||||||
):
|
|
||||||
intro, doc = split_arg_line(lines[idx])
|
|
||||||
# Line starting with a > after indent indicate a "section title" in the parameters.
|
|
||||||
if intro.lstrip().startswith(">"):
|
|
||||||
lines[idx] = intro.lstrip()
|
|
||||||
else:
|
|
||||||
lines[idx] = (
|
|
||||||
re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
|
|
||||||
)
|
|
||||||
idx += 1
|
|
||||||
while idx < len(lines) and (
|
|
||||||
is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
|
|
||||||
):
|
|
||||||
idx += 1
|
|
||||||
lines.insert(idx, "</parameters>\n")
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
# Returns section
|
|
||||||
elif _re_returns.search(lines[idx]) is not None:
|
|
||||||
# tag is either `return` or `yield`
|
|
||||||
tag = _re_returns.match(lines[idx]).group(1).lower()
|
|
||||||
# Title of the section.
|
|
||||||
lines[idx] = f"<{tag}s>\n"
|
|
||||||
# Find the next nonempty line
|
|
||||||
idx += 1
|
|
||||||
while is_empty_line(lines[idx]):
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
# Grab the indent of the return line, this block will stop when we unindent under it.
|
|
||||||
return_indent = find_indent(lines[idx])
|
|
||||||
raised_errors = []
|
|
||||||
# The line may contain the return type.
|
|
||||||
if tag in ["return", "yield"]:
|
|
||||||
return_type, return_description = split_return_line(lines[idx])
|
|
||||||
lines[idx] = return_description
|
|
||||||
idx += 1
|
|
||||||
while idx < len(lines) and (
|
|
||||||
is_empty_line(lines[idx])
|
|
||||||
or find_indent(lines[idx]) >= return_indent
|
|
||||||
):
|
|
||||||
idx += 1
|
|
||||||
else:
|
|
||||||
while idx < len(lines) and find_indent(lines[idx]) == return_indent:
|
|
||||||
return_type, return_description = split_raise_line(lines[idx])
|
|
||||||
raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
|
|
||||||
lines[idx] = "- " + raised_error + " -- " + return_description
|
|
||||||
md_link = _re_md_link.match(raised_error)
|
|
||||||
if md_link:
|
|
||||||
raised_error = md_link[1]
|
|
||||||
raised_error = re.sub(
|
|
||||||
r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
|
|
||||||
)
|
|
||||||
if raised_error not in raised_errors:
|
|
||||||
raised_errors.append(raised_error)
|
|
||||||
idx += 1
|
|
||||||
while idx < len(lines) and (
|
|
||||||
is_empty_line(lines[idx])
|
|
||||||
or find_indent(lines[idx]) > return_indent
|
|
||||||
):
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
lines.insert(idx, f"</{tag}s>\n")
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
# Return block finished, we insert the return type if one was specified
|
|
||||||
if tag in ["return", "yield"] and return_type is not None:
|
|
||||||
lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
|
|
||||||
elif len(raised_errors) > 0:
|
|
||||||
# raised errors
|
|
||||||
lines[
|
|
||||||
idx - 1
|
|
||||||
] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
|
|
||||||
|
|
||||||
else:
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
result = "\n".join(lines)
|
|
||||||
|
|
||||||
# combine multiple <parameters> blocks into one block
|
|
||||||
if result.count("<parameters>") > 1:
|
|
||||||
parameters_blocks = _re_parameters.findall(result)
|
|
||||||
parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
|
|
||||||
parameters_str = "\n".join(parameters_blocks)
|
|
||||||
result = _re_parameters.sub("", result)
|
|
||||||
result += f"\n<parameters>{parameters_str}</parameters>\n"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
|
|
||||||
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
|
|
||||||
|
|
||||||
|
|
||||||
def remove_indent(text):
|
|
||||||
"""
|
|
||||||
Remove indents in text, except the one linked to lists (or sublists).
|
|
||||||
"""
|
|
||||||
lines = text.split("\n")
|
|
||||||
# List of indents to remember for nested lists
|
|
||||||
current_indents = []
|
|
||||||
# List of new indents to remember for nested lists
|
|
||||||
new_indents = []
|
|
||||||
is_inside_code = False
|
|
||||||
code_indent = 0
|
|
||||||
for idx, line in enumerate(lines):
|
|
||||||
# Line is an item in a list.
|
|
||||||
if _re_list.search(line) is not None:
|
|
||||||
indent = find_indent(line)
|
|
||||||
# Is it a new list / new level of nestedness?
|
|
||||||
if len(current_indents) == 0 or indent > current_indents[-1]:
|
|
||||||
current_indents.append(indent)
|
|
||||||
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
|
|
||||||
lines[idx] = " " * new_indent + line[indent:]
|
|
||||||
new_indent += len(_re_list.search(line).groups()[0]) + 1
|
|
||||||
new_indents.append(new_indent)
|
|
||||||
# Otherwise it's an existing level of list (current one, or previous one)
|
|
||||||
else:
|
|
||||||
# Let's find the proper level of indentation
|
|
||||||
level = len(current_indents) - 1
|
|
||||||
while level >= 0 and current_indents[level] != indent:
|
|
||||||
level -= 1
|
|
||||||
current_indents = current_indents[: level + 1]
|
|
||||||
new_indents = new_indents[:level]
|
|
||||||
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
|
|
||||||
lines[idx] = " " * new_indent + line[indent:]
|
|
||||||
new_indent += len(_re_list.search(line).groups()[0]) + 1
|
|
||||||
new_indents.append(new_indent)
|
|
||||||
|
|
||||||
# Line is an autodoc, we keep the indent for the list just after if there is one.
|
|
||||||
elif _re_autodoc.search(line) is not None:
|
|
||||||
indent = find_indent(line)
|
|
||||||
current_indents = [indent]
|
|
||||||
new_indents = [4]
|
|
||||||
lines[idx] = line.strip()
|
|
||||||
|
|
||||||
# Deal with empty lines separately
|
|
||||||
elif is_empty_line(line):
|
|
||||||
lines[idx] = ""
|
|
||||||
|
|
||||||
# Code blocks
|
|
||||||
elif line.lstrip().startswith("```"):
|
|
||||||
is_inside_code = not is_inside_code
|
|
||||||
if is_inside_code:
|
|
||||||
code_indent = find_indent(line)
|
|
||||||
lines[idx] = line[code_indent:]
|
|
||||||
elif is_inside_code:
|
|
||||||
lines[idx] = line[code_indent:]
|
|
||||||
|
|
||||||
else:
|
|
||||||
indent = find_indent(line)
|
|
||||||
if len(current_indents) > 0 and indent > current_indents[-1]:
|
|
||||||
lines[idx] = " " * new_indents[-1] + line[indent:]
|
|
||||||
elif len(current_indents) > 0:
|
|
||||||
# Let's find the proper level of indentation
|
|
||||||
level = len(current_indents) - 1
|
|
||||||
while level >= 0 and current_indents[level] > indent:
|
|
||||||
level -= 1
|
|
||||||
current_indents = current_indents[: level + 1]
|
|
||||||
if level >= 0:
|
|
||||||
if current_indents[level] < indent:
|
|
||||||
new_indents = new_indents[: level + 1]
|
|
||||||
else:
|
|
||||||
new_indents = new_indents[:level]
|
|
||||||
new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
|
|
||||||
lines[idx] = " " * new_indent + line[indent:]
|
|
||||||
new_indents.append(new_indent)
|
|
||||||
else:
|
|
||||||
new_indents = []
|
|
||||||
lines[idx] = line[indent:]
|
|
||||||
else:
|
|
||||||
lines[idx] = line[indent:]
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def base_rst_to_mdx(text, page_info, unindent=True):
|
|
||||||
"""
|
|
||||||
Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs.
|
|
||||||
"""
|
|
||||||
text = convert_rst_links(text, page_info)
|
|
||||||
text = convert_special_chars(text)
|
|
||||||
text = convert_rst_blocks(text, page_info)
|
|
||||||
# Convert * in lists to - to avoid the formatting conversion treat them as bold.
|
|
||||||
text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE)
|
|
||||||
text = convert_rst_formatting(text)
|
|
||||||
return remove_indent(text) if unindent else text
|
|
||||||
|
|
||||||
|
|
||||||
def convert_rst_docstring_to_mdx(docstring, page_info):
|
|
||||||
"""
|
|
||||||
Convert a docstring written in rst to mdx.
|
|
||||||
"""
|
|
||||||
text = parse_rst_docstring(docstring)
|
|
||||||
return base_rst_to_mdx(text, page_info)
|
|
||||||
|
|
||||||
|
|
||||||
def process_titles(lines):
|
|
||||||
"""Converts rst titles to markdown titles."""
|
|
||||||
title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
|
|
||||||
title_levels = {}
|
|
||||||
new_lines = []
|
|
||||||
for line in lines:
|
|
||||||
if (
|
|
||||||
len(new_lines) > 0
|
|
||||||
and len(line) >= len(new_lines[-1])
|
|
||||||
and len(set(line)) == 1
|
|
||||||
and line[0] in title_chars
|
|
||||||
and line != "::"
|
|
||||||
):
|
|
||||||
char = line[0]
|
|
||||||
level = title_levels.get(char, len(title_levels) + 1)
|
|
||||||
if level not in title_levels:
|
|
||||||
title_levels[char] = level
|
|
||||||
new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
|
|
||||||
else:
|
|
||||||
new_lines.append(line)
|
|
||||||
return new_lines
|
|
||||||
|
|
||||||
|
|
||||||
# Matches lines with a pattern of a table new line in rst.
|
|
||||||
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
|
|
||||||
# Matches lines with a pattern of a table new line in rst, with a first column empty.
|
|
||||||
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
|
|
||||||
# Matches lines with a pattern of a first table line in rst.
|
|
||||||
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
|
|
||||||
# Re pattern that catches anchors of the type .. reference:
|
|
||||||
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
|
|
||||||
|
|
||||||
|
|
||||||
def split_pt_tf_code_blocks(text):
|
|
||||||
"""
|
|
||||||
Split PyTorch and TensorFlow specific block codes.
|
|
||||||
"""
|
|
||||||
lines = text.split("\n")
|
|
||||||
new_lines = []
|
|
||||||
idx = 0
|
|
||||||
while idx < len(lines):
|
|
||||||
if lines[idx].startswith("```"):
|
|
||||||
code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
|
|
||||||
is_pytorch = False
|
|
||||||
is_tensorflow = False
|
|
||||||
idx += 1
|
|
||||||
while idx < len(lines) and lines[idx].strip() != "```":
|
|
||||||
if "## PYTORCH CODE" in lines[idx]:
|
|
||||||
is_pytorch = True
|
|
||||||
is_tensorflow = False
|
|
||||||
elif "## TENSORFLOW CODE" in lines[idx]:
|
|
||||||
is_tensorflow = True
|
|
||||||
is_pytorch = False
|
|
||||||
elif is_pytorch:
|
|
||||||
code_lines["pytorch"].append(lines[idx])
|
|
||||||
elif is_tensorflow:
|
|
||||||
code_lines["tensorflow"].append(lines[idx])
|
|
||||||
else:
|
|
||||||
code_lines["common"].append(lines[idx])
|
|
||||||
idx += 1
|
|
||||||
if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0:
|
|
||||||
block_lines = ["<frameworkcontent>", "<pt>"]
|
|
||||||
block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"])
|
|
||||||
block_lines.extend(["```", "</pt>", "<tf>"])
|
|
||||||
block_lines.extend(
|
|
||||||
code_lines["common"].copy() + code_lines["tensorflow"]
|
|
||||||
)
|
|
||||||
block_lines.extend(["```", "</tf>", "</frameworkcontent>"])
|
|
||||||
new_lines.extend(block_lines)
|
|
||||||
else:
|
|
||||||
block_lines = code_lines["common"] + ["```"]
|
|
||||||
new_lines.extend(block_lines)
|
|
||||||
idx += 1
|
|
||||||
else:
|
|
||||||
new_lines.append(lines[idx])
|
|
||||||
idx += 1
|
|
||||||
return "\n".join(new_lines)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
|
|
||||||
"""
|
|
||||||
Convert a document written in rst to mdx.
|
|
||||||
"""
|
|
||||||
lines = rst_text.split("\n")
|
|
||||||
lines = process_titles(lines)
|
|
||||||
if add_imports:
|
|
||||||
new_lines = [
|
|
||||||
'<script lang="ts">',
|
|
||||||
' import Tip from "$lib/Tip.svelte";',
|
|
||||||
' import Youtube from "$lib/Youtube.svelte";',
|
|
||||||
' import Docstring from "$lib/Docstring.svelte";',
|
|
||||||
' import CodeBlock from "$lib/CodeBlock.svelte";',
|
|
||||||
' import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
|
|
||||||
' import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
|
|
||||||
' import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
|
|
||||||
' import IconCopyLink from "$lib/IconCopyLink.svelte";',
|
|
||||||
' import FrameworkContent from "$lib/FrameworkContent.svelte";',
|
|
||||||
' import Markdown from "$lib/Markdown.svelte";',
|
|
||||||
' import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
|
|
||||||
' import Added from "$lib/Added.svelte";',
|
|
||||||
' import Changed from "$lib/Changed.svelte";',
|
|
||||||
' import Deprecated from "$lib/Deprecated.svelte";',
|
|
||||||
' import PipelineIcon from "$lib/PipelineIcon.svelte";',
|
|
||||||
' import PipelineTag from "$lib/PipelineTag.svelte";',
|
|
||||||
" ",
|
|
||||||
' export let fw: "pt" | "tf"',
|
|
||||||
"</script>",
|
|
||||||
"<svelte:head>",
|
|
||||||
'<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
|
|
||||||
"</svelte:head>",
|
|
||||||
"",
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
new_lines = []
|
|
||||||
for line in lines:
|
|
||||||
if _re_ignore_line_table.search(line) is not None:
|
|
||||||
continue
|
|
||||||
elif _re_ignore_line_table1.search(line) is not None:
|
|
||||||
continue
|
|
||||||
elif _re_sep_line_table.search(line) is not None:
|
|
||||||
line = line.replace("=", "-").replace("+", "|")
|
|
||||||
elif _re_anchor_section.search(line) is not None:
|
|
||||||
anchor_name = _re_anchor_section.search(line).groups()[0]
|
|
||||||
line = f"<a id='{anchor_name}'></a>"
|
|
||||||
new_lines.append(line)
|
|
||||||
text = "\n".join(new_lines)
|
|
||||||
|
|
||||||
return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))
|
|
||||||
@ -8,9 +8,6 @@ from kernels.compat import tomllib
|
|||||||
from kernels.lockfile import KernelLock, get_kernel_locks
|
from kernels.lockfile import KernelLock, get_kernel_locks
|
||||||
from kernels.utils import install_kernel, install_kernel_all_variants
|
from kernels.utils import install_kernel, install_kernel_all_variants
|
||||||
|
|
||||||
from .doc import generate_readme_for_kernel
|
|
||||||
from .wheel import build_variant_to_wheel
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -39,47 +36,6 @@ def main():
|
|||||||
)
|
)
|
||||||
lock_parser.set_defaults(func=lock_kernels)
|
lock_parser.set_defaults(func=lock_kernels)
|
||||||
|
|
||||||
to_wheel_parser = subparsers.add_parser(
|
|
||||||
"to-wheel", help="Convert a kernel to a wheel file"
|
|
||||||
)
|
|
||||||
to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
|
|
||||||
to_wheel_parser.add_argument("version", type=str, help="The kernel version")
|
|
||||||
to_wheel_parser.add_argument(
|
|
||||||
"--python-version",
|
|
||||||
type=str,
|
|
||||||
default="3.9",
|
|
||||||
help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
|
|
||||||
)
|
|
||||||
to_wheel_parser.add_argument(
|
|
||||||
"--manylinux-version",
|
|
||||||
type=str,
|
|
||||||
default="2.28",
|
|
||||||
help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
|
|
||||||
)
|
|
||||||
to_wheel_parser.set_defaults(func=kernels_to_wheel)
|
|
||||||
|
|
||||||
# Add generate-readme subcommand parser
|
|
||||||
generate_readme_parser = subparsers.add_parser(
|
|
||||||
"generate-readme",
|
|
||||||
help="Generate README snippets for a kernel's public functions",
|
|
||||||
)
|
|
||||||
generate_readme_parser.add_argument(
|
|
||||||
"repo_id",
|
|
||||||
type=str,
|
|
||||||
help="The kernel repo ID (e.g., kernels-community/activation)",
|
|
||||||
)
|
|
||||||
generate_readme_parser.add_argument(
|
|
||||||
"--revision",
|
|
||||||
type=str,
|
|
||||||
default="main",
|
|
||||||
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
|
|
||||||
)
|
|
||||||
generate_readme_parser.set_defaults(
|
|
||||||
func=lambda args: generate_readme_for_kernel(
|
|
||||||
repo_id=args.repo_id, revision=args.revision
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.func(args)
|
args.func(args)
|
||||||
|
|
||||||
@ -121,24 +77,6 @@ def download_kernels(args):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def kernels_to_wheel(args):
|
|
||||||
variants_path = install_kernel_all_variants(
|
|
||||||
repo_id=args.repo_id, revision=f"v{args.version}"
|
|
||||||
)
|
|
||||||
for variant_path in variants_path.iterdir():
|
|
||||||
if not variant_path.is_dir():
|
|
||||||
continue
|
|
||||||
wheel_path = build_variant_to_wheel(
|
|
||||||
manylinux_version=args.manylinux_version,
|
|
||||||
python_version=args.python_version,
|
|
||||||
repo_id=args.repo_id,
|
|
||||||
version=args.version,
|
|
||||||
variant_path=variant_path,
|
|
||||||
wheel_dir=Path("."),
|
|
||||||
)
|
|
||||||
print(f"☸️ {wheel_path.name}", file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def lock_kernels(args):
|
def lock_kernels(args):
|
||||||
with open(args.project_dir / "pyproject.toml", "rb") as f:
|
with open(args.project_dir / "pyproject.toml", "rb") as f:
|
||||||
data = tomllib.load(f)
|
data = tomllib.load(f)
|
||||||
|
|||||||
@ -1,242 +0,0 @@
|
|||||||
import inspect
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from types import ModuleType
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
|
|
||||||
from .utils import get_kernel
|
|
||||||
|
|
||||||
_RE_PARAMETERS = re.compile(
|
|
||||||
r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
|
|
||||||
)
|
|
||||||
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
|
|
||||||
_RE_RETURNTYPE = re.compile(
|
|
||||||
r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_description_before_tags(docstring_mdx: str) -> str:
|
|
||||||
"""Extract the description part of a docstring before any tags."""
|
|
||||||
params_pos = docstring_mdx.find("<parameters>")
|
|
||||||
returns_pos = docstring_mdx.find("<returns>")
|
|
||||||
returntype_pos = docstring_mdx.find("<returntype>")
|
|
||||||
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
|
|
||||||
|
|
||||||
if positions:
|
|
||||||
first_tag_pos = min(positions)
|
|
||||||
return docstring_mdx[:first_tag_pos].strip()
|
|
||||||
else:
|
|
||||||
return docstring_mdx.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
|
|
||||||
"""Print the parameters section from a docstring."""
|
|
||||||
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
|
||||||
if matches:
|
|
||||||
header = "#" * header_level
|
|
||||||
print(f"\n{header} Parameters")
|
|
||||||
for match in matches:
|
|
||||||
print(f"\n{match[0].strip()}")
|
|
||||||
|
|
||||||
|
|
||||||
def _print_returns_section(
|
|
||||||
docstring_mdx: str, *, context_name: str, header_level: int
|
|
||||||
) -> None:
|
|
||||||
"""Print the returns section from a docstring."""
|
|
||||||
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
|
||||||
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
|
||||||
|
|
||||||
if return_matches or returntype_matches:
|
|
||||||
header = "#" * header_level
|
|
||||||
print(f"\n{header} Returns")
|
|
||||||
|
|
||||||
if returntype_matches:
|
|
||||||
if len(returntype_matches) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"More than one <returntype> tag found in docstring for {context_name}"
|
|
||||||
)
|
|
||||||
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
|
|
||||||
|
|
||||||
if return_matches:
|
|
||||||
for match in return_matches:
|
|
||||||
print(f"\n{match[0].strip()}")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_docstring(obj, use_dict_check: bool = False) -> str:
|
|
||||||
"""Get docstring from an object, with fallback to default message."""
|
|
||||||
# Check whether the class/method itself has docs and not just
|
|
||||||
# the superclass.
|
|
||||||
if use_dict_check:
|
|
||||||
has_doc = obj.__dict__.get("__doc__", None) is not None
|
|
||||||
else:
|
|
||||||
has_doc = getattr(obj, "__doc__", None) is not None
|
|
||||||
|
|
||||||
# We use inspect.getdoc because it does normalization.
|
|
||||||
doc = inspect.getdoc(obj)
|
|
||||||
|
|
||||||
return doc if has_doc and doc is not None else "No documentation available."
|
|
||||||
|
|
||||||
|
|
||||||
def _process_and_print_docstring(
|
|
||||||
docstring: str, *, kernel_name: str, context_name: str, header_level: int
|
|
||||||
) -> None:
|
|
||||||
"""Convert docstring to MDX and print description, parameters, and returns sections."""
|
|
||||||
docstring_mdx = convert_rst_docstring_to_mdx(
|
|
||||||
docstring, page_info={"package_name": kernel_name}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print the description
|
|
||||||
description = _extract_description_before_tags(docstring_mdx)
|
|
||||||
print(f"\n{description}")
|
|
||||||
|
|
||||||
# Print parameters and returns sections
|
|
||||||
_print_parameters_section(docstring_mdx, header_level=header_level)
|
|
||||||
_print_returns_section(
|
|
||||||
docstring_mdx, context_name=context_name, header_level=header_level
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
|
||||||
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
|
|
||||||
kernel_name = repo_id.split("/")[-1].replace("-", "_")
|
|
||||||
|
|
||||||
generate_metadata(kernel_module)
|
|
||||||
generate_kernel_doc(kernel_module, kernel_name)
|
|
||||||
generate_function_doc(kernel_module, kernel_name)
|
|
||||||
generate_layers_doc(kernel_module, kernel_name)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_metadata(module: ModuleType) -> None:
|
|
||||||
metadata = getattr(module, "__kernel_metadata__", {})
|
|
||||||
if "tags" not in metadata:
|
|
||||||
metadata["tags"] = ["kernel"]
|
|
||||||
else:
|
|
||||||
if "kernel" not in metadata["tags"]:
|
|
||||||
metadata["tags"].append("kernel")
|
|
||||||
|
|
||||||
print("---")
|
|
||||||
print(yaml.dump(metadata), end="")
|
|
||||||
print("---")
|
|
||||||
|
|
||||||
|
|
||||||
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
|
|
||||||
docstring = module.__doc__.strip() if module.__doc__ is not None else None
|
|
||||||
if docstring:
|
|
||||||
title, rest = docstring.split("\n", 1)
|
|
||||||
print(f"# {title.strip()}")
|
|
||||||
print(
|
|
||||||
f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
|
||||||
print("\n## Functions")
|
|
||||||
|
|
||||||
# Track if we found any functions
|
|
||||||
found_functions = False
|
|
||||||
|
|
||||||
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
|
||||||
# Do not include imported functions.
|
|
||||||
if func.__module__ != kernel_module.__name__:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Exclude private functions.
|
|
||||||
if name.startswith("_"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
found_functions = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
sig = inspect.signature(func)
|
|
||||||
docstring = _get_docstring(func)
|
|
||||||
except ValueError:
|
|
||||||
print(
|
|
||||||
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"\n### Function `{name}`")
|
|
||||||
print(f"\n`{sig}`")
|
|
||||||
|
|
||||||
_process_and_print_docstring(
|
|
||||||
docstring, kernel_name=kernel_name, context_name=name, header_level=3
|
|
||||||
)
|
|
||||||
|
|
||||||
if not found_functions:
|
|
||||||
print("\nNo public top-level functions.")
|
|
||||||
|
|
||||||
|
|
||||||
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
|
||||||
# Check if layers module is available
|
|
||||||
layers_module = getattr(kernel_module, "layers", None)
|
|
||||||
if layers_module is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
print("\n## Layers")
|
|
||||||
|
|
||||||
# Track if we found any classes
|
|
||||||
found_classes = False
|
|
||||||
|
|
||||||
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
|
|
||||||
# Exclude classes that were imported.
|
|
||||||
if cls.__module__ != layers_module.__name__:
|
|
||||||
continue
|
|
||||||
|
|
||||||
found_classes = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get docstring, but not from superclasses.
|
|
||||||
class_docstring = _get_docstring(cls, use_dict_check=True)
|
|
||||||
except Exception:
|
|
||||||
print(
|
|
||||||
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"\n### Class `{class_name}`")
|
|
||||||
|
|
||||||
# Always print class description (helper handles conversion and formatting)
|
|
||||||
class_docstring_mdx = convert_rst_docstring_to_mdx(
|
|
||||||
class_docstring, page_info={"package_name": kernel_name}
|
|
||||||
)
|
|
||||||
description = _extract_description_before_tags(class_docstring_mdx)
|
|
||||||
print(f"\n{description}")
|
|
||||||
|
|
||||||
# Document methods
|
|
||||||
print("\n#### Methods")
|
|
||||||
|
|
||||||
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
|
|
||||||
# Note: also skip __init__, since extension layers cannot have a constructor.
|
|
||||||
if method_name.startswith("_"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip methods from superclasses.
|
|
||||||
if method_name not in cls.__dict__:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
sig = inspect.signature(method)
|
|
||||||
method_docstring = _get_docstring(method)
|
|
||||||
except ValueError:
|
|
||||||
print(
|
|
||||||
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"\n##### Method `{method_name}`")
|
|
||||||
print(f"\n`{sig}`")
|
|
||||||
|
|
||||||
_process_and_print_docstring(
|
|
||||||
method_docstring,
|
|
||||||
kernel_name=kernel_name,
|
|
||||||
context_name=method_name,
|
|
||||||
header_level=6,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not found_classes:
|
|
||||||
print("\nNo layers defined.")
|
|
||||||
@ -1,67 +1,19 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Flag, auto
|
from typing import TYPE_CHECKING, Dict, Union
|
||||||
from types import MethodType
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Dict,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .utils import get_kernel
|
from .utils import get_kernel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import torch
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
|
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
|
||||||
|
|
||||||
|
|
||||||
class Mode(Flag):
|
|
||||||
"""
|
|
||||||
Kernelize mode
|
|
||||||
|
|
||||||
The `Mode` flag is used by `kernelize` to select kernels for the given
|
|
||||||
mode. Mappings can be registered for specific modes.
|
|
||||||
|
|
||||||
* `INFERENCE`: The kernel is used for inference.
|
|
||||||
* `TRAINING`: The kernel is used for training.
|
|
||||||
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
|
|
||||||
* `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
|
|
||||||
matches.
|
|
||||||
|
|
||||||
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
|
|
||||||
should be used for layers that are used for inference *with* `torch.compile`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_NONE = 0
|
|
||||||
DEFAULT = auto()
|
|
||||||
TRAINING = auto()
|
|
||||||
INFERENCE = auto()
|
|
||||||
TORCH_COMPILE = auto()
|
|
||||||
|
|
||||||
def __or__(self, other: Mode) -> Mode:
|
|
||||||
union = super().__or__(other)
|
|
||||||
|
|
||||||
if Mode.INFERENCE in union and Mode.TRAINING in union:
|
|
||||||
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
|
|
||||||
|
|
||||||
if Mode.DEFAULT in union and union != Mode.DEFAULT:
|
|
||||||
raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
|
|
||||||
|
|
||||||
return union
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Device:
|
class Device:
|
||||||
type: str
|
type: str
|
||||||
@ -101,19 +53,13 @@ class LayerRepository:
|
|||||||
return hash((self.layer_name, self.repo_id, self.revision))
|
return hash((self.layer_name, self.repo_id, self.revision))
|
||||||
|
|
||||||
|
|
||||||
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
|
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
|
||||||
|
"_KERNEL_MAPPING", default={}
|
||||||
|
|
||||||
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
|
|
||||||
ContextVar("_KERNEL_MAPPING", default={})
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def use_kernel_mapping(
|
def use_kernel_mapping(
|
||||||
mapping: Dict[
|
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
|
||||||
str,
|
|
||||||
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
|
|
||||||
],
|
|
||||||
*,
|
*,
|
||||||
inherit_mapping: bool = True,
|
inherit_mapping: bool = True,
|
||||||
):
|
):
|
||||||
@ -141,17 +87,12 @@ def use_kernel_mapping(
|
|||||||
|
|
||||||
|
|
||||||
def register_kernel_mapping(
|
def register_kernel_mapping(
|
||||||
mapping: Dict[
|
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
|
||||||
str,
|
|
||||||
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
|
|
||||||
],
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Allows one to register a mapping between a layer name and the corresponding
|
Allows one to register a mapping between a layer name the corresponding kernel to use, depending on the device.
|
||||||
kernel(s) to use, depending on the device. This should be used in conjunction
|
This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
|
||||||
with `kernelize`.
|
Exemple usage:
|
||||||
|
|
||||||
Example usage:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from kernels import LayerRepository, register_kernel_mapping
|
from kernels import LayerRepository, register_kernel_mapping
|
||||||
@ -172,106 +113,32 @@ def register_kernel_mapping(
|
|||||||
for new_kernel, new_device_repos in mapping.items():
|
for new_kernel, new_device_repos in mapping.items():
|
||||||
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
|
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
|
||||||
for new_device, new_repo in new_device_repos.items():
|
for new_device, new_repo in new_device_repos.items():
|
||||||
device = (
|
if isinstance(new_device, str):
|
||||||
Device(type=new_device) if isinstance(new_device, str) else new_device
|
device_repo[Device(type=new_device)] = new_repo
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(new_repo, LayerRepository):
|
|
||||||
kernel_options = {Mode.DEFAULT: new_repo}
|
|
||||||
else:
|
else:
|
||||||
kernel_options = new_repo
|
device_repo[new_device] = new_repo
|
||||||
|
|
||||||
device_repo[device] = kernel_options
|
|
||||||
|
|
||||||
|
|
||||||
def replace_kernel_forward_from_hub(
|
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
|
||||||
cls,
|
|
||||||
layer_name: str,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
|
Replace the forward function of a layer using a layer from the kernel hub.
|
||||||
|
This function monkeypatches a layer, replacing the `forward` method
|
||||||
This decorator stores the layer name and original forward method, which will be used
|
of the layer with that of a layer from the hub. The replacement is done
|
||||||
by the kernelize function to replace the forward implementation with the appropriate
|
when a layer matching `layer_name` and device type is registered through
|
||||||
kernel from the hub.
|
`register_layer_mapping`. The device type is inferred from the first
|
||||||
|
argument to `forward`.
|
||||||
Args:
|
|
||||||
cls: The layer class to decorate
|
|
||||||
layer_name: The name of the layer to use for kernel lookup
|
|
||||||
"""
|
"""
|
||||||
cls.kernel_layer_name = layer_name
|
|
||||||
|
|
||||||
|
fallback_forward = cls.forward
|
||||||
|
|
||||||
def _select_repository(
|
cached_layer: Dict[LayerRepository, nn.Module] = {}
|
||||||
repositories: Dict[Mode, LayerRepository],
|
|
||||||
*,
|
|
||||||
mode: Mode,
|
|
||||||
) -> Optional[Tuple[LayerRepository, Mode]]:
|
|
||||||
if mode in repositories:
|
|
||||||
return (repositories[mode], mode)
|
|
||||||
elif Mode.DEFAULT in repositories:
|
|
||||||
return (repositories[Mode.DEFAULT], Mode.DEFAULT)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def kernelize(
|
|
||||||
model: "nn.Module",
|
|
||||||
*,
|
|
||||||
mode: Mode,
|
|
||||||
device: Optional[Union[str, "torch.device"]] = None,
|
|
||||||
use_fallback: bool = True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Iterate over all modules in the model and replace the `forward` method of
|
|
||||||
extensible layers for which kernels are registered using `register_kernel_mapping`
|
|
||||||
or `use_kernel_mapping`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The PyTorch model to kernelize
|
|
||||||
mode: the mode that the kernel is going to be used in (e.g.
|
|
||||||
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
|
|
||||||
and `torch.compile`).
|
|
||||||
device: The device type to load kernels for. The device type will be inferred
|
|
||||||
from the parameters of the model when not provided.
|
|
||||||
use_fallback: Whether to use the original forward method of modules when no
|
|
||||||
compatible kernel could be found. If set to `False`, an exception will
|
|
||||||
be raised in such cases.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The kernelized model
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if mode == Mode.DEFAULT:
|
|
||||||
raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
|
|
||||||
|
|
||||||
# Type check ignored because this causes a false negative on Python < 3.11.
|
|
||||||
# Looks similar to: https://github.com/python/mypy/issues/9642
|
|
||||||
# Remove once we start doing typing checks on >= 3.11.
|
|
||||||
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
|
|
||||||
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
|
|
||||||
|
|
||||||
if device is None:
|
|
||||||
device_type = _find_device(model)
|
|
||||||
elif isinstance(device, str):
|
|
||||||
device_type = Device(type=torch.device(device).type)
|
|
||||||
else:
|
|
||||||
device_type = Device(device.type)
|
|
||||||
assert isinstance(device_type, Device)
|
|
||||||
|
|
||||||
for _, module in model.named_modules():
|
|
||||||
module_class = type(module)
|
|
||||||
if not hasattr(module_class, "kernel_layer_name"):
|
|
||||||
continue
|
|
||||||
layer_name = module_class.kernel_layer_name
|
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
if _DISABLE_KERNEL_MAPPING:
|
if _DISABLE_KERNEL_MAPPING:
|
||||||
_replace_forward(module, module_class)
|
return fallback_forward(self, x, *args, **kwargs)
|
||||||
continue
|
|
||||||
|
|
||||||
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
|
|
||||||
|
|
||||||
|
needs_backward = self.training
|
||||||
|
kernel = _KERNEL_MAPPING.get().get(layer_name)
|
||||||
if kernel is None:
|
if kernel is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"\n"
|
"\n"
|
||||||
@ -281,70 +148,68 @@ def kernelize(
|
|||||||
)
|
)
|
||||||
if not use_fallback:
|
if not use_fallback:
|
||||||
raise ValueError(f"No layer mapping for `{layer_name}`")
|
raise ValueError(f"No layer mapping for `{layer_name}`")
|
||||||
_replace_forward(module, module_class)
|
return fallback_forward(self, x, *args, **kwargs)
|
||||||
continue
|
|
||||||
|
|
||||||
# Get kernel options for the device
|
device = getattr(x, "device", None)
|
||||||
repos = kernel.get(device_type)
|
if device is None:
|
||||||
|
return fallback_forward(self, x, *args, **kwargs)
|
||||||
|
|
||||||
if repos is None:
|
repo = kernel.get(Device(type=device.type))
|
||||||
|
if repo is None:
|
||||||
if not use_fallback:
|
if not use_fallback:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
|
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
|
||||||
)
|
)
|
||||||
_replace_forward(module, module_class)
|
return fallback_forward(self, x, *args, **kwargs)
|
||||||
continue
|
|
||||||
|
|
||||||
repo_with_mode = _select_repository(
|
# Short-circuit if we already loaded the layer.
|
||||||
repos,
|
layer = cached_layer.get(repo, None)
|
||||||
mode=mode,
|
if layer is not None:
|
||||||
|
if needs_backward and not getattr(layer, "has_backward", True):
|
||||||
|
return fallback_forward(self, x, *args, **kwargs)
|
||||||
|
return layer.forward(self, x, *args, **kwargs)
|
||||||
|
|
||||||
|
layer = _get_kernel_layer(
|
||||||
|
repo_id=repo.repo_id,
|
||||||
|
layer_name=repo.layer_name,
|
||||||
|
revision=repo.revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
if repo_with_mode is None:
|
# We have to validate against the original signature.
|
||||||
if not use_fallback:
|
orig_forward = cls.forward
|
||||||
raise ValueError(
|
try:
|
||||||
f"No repository for `{layer_name}` for configuration mode={mode}"
|
cls.forward = fallback_forward
|
||||||
)
|
_validate_layer(check_cls=cls, cls=layer)
|
||||||
_replace_forward(module, module_class)
|
finally:
|
||||||
continue
|
cls.forward = orig_forward
|
||||||
|
|
||||||
repo, repo_mode = repo_with_mode
|
cached_layer[repo] = layer
|
||||||
|
|
||||||
layer = _get_layer_memoize(repo, module_class)
|
if needs_backward and not getattr(layer, "has_backward", True):
|
||||||
|
return fallback_forward(self, x, *args, **kwargs)
|
||||||
|
return layer.forward(self, x, *args, **kwargs)
|
||||||
|
|
||||||
# Ideally we would do validation on the mapping where we check that
|
cls.forward = forward
|
||||||
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
|
|
||||||
# the actual layer is compatible with that. Unfortunately, this would
|
|
||||||
# mean that we have to pre-download everything.
|
|
||||||
_validate_layer_has_mode(
|
|
||||||
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
|
|
||||||
)
|
|
||||||
|
|
||||||
_conditionally_replace_forward(
|
|
||||||
module=module,
|
|
||||||
layer=layer,
|
|
||||||
mode=mode,
|
|
||||||
use_fallback=use_fallback,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def use_kernel_forward_from_hub(layer_name: str):
|
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
|
||||||
"""
|
"""
|
||||||
Make a layer extensible using the name `layer_name`.
|
Replace the forward function of a layer using a layer from the kernel hub.
|
||||||
|
This decorator can be applied to a layer and replaces the forward method
|
||||||
|
of the layer with that of a layer from the hub. The replacement is done
|
||||||
|
when a layer matching `layer_name` and device type is registered through
|
||||||
|
`register_layer_mapping`. The device type is inferred from the first
|
||||||
|
argument to `forward`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
replace_kernel_forward_from_hub(cls, layer_name)
|
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def _get_kernel_layer(
|
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
|
||||||
*, repo_id: str, layer_name: str, revision: str
|
|
||||||
) -> Type["nn.Module"]:
|
|
||||||
"""Get a layer from a kernel."""
|
"""Get a layer from a kernel."""
|
||||||
|
|
||||||
kernel = get_kernel(repo_id, revision=revision)
|
kernel = get_kernel(repo_id, revision=revision)
|
||||||
@ -361,13 +226,13 @@ def _get_kernel_layer(
|
|||||||
|
|
||||||
|
|
||||||
def _validate_layer(*, check_cls, cls):
|
def _validate_layer(*, check_cls, cls):
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
# The layer must have at least have the following properties: (1) it
|
# The layer must have at least have the following properties: (1) it
|
||||||
# must be stateless; (2) the forward signature should correspond to
|
# must be stateless; (2) the forward signature should correspond to
|
||||||
# the signature it is replacing; (3) forward should not call other
|
# the signature it is replacing; (3) forward should not call other
|
||||||
# methods.
|
# methods.
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
if not issubclass(cls, nn.Module):
|
if not issubclass(cls, nn.Module):
|
||||||
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
|
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
|
||||||
|
|
||||||
@ -380,8 +245,7 @@ def _validate_layer(*, check_cls, cls):
|
|||||||
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
|
||||||
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
cls_members = {name for name, _ in inspect.getmembers(cls)}
|
||||||
difference = cls_members - torch_module_members
|
difference = cls_members - torch_module_members
|
||||||
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
|
if difference != set() and difference != {"has_backward"}:
|
||||||
if not difference <= {"can_torch_compile", "has_backward"}:
|
|
||||||
raise TypeError("Layer must not contain additional members.")
|
raise TypeError("Layer must not contain additional members.")
|
||||||
|
|
||||||
# Check whether the forward signatures are similar.
|
# Check whether the forward signatures are similar.
|
||||||
@ -398,92 +262,3 @@ def _validate_layer(*, check_cls, cls):
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _find_device(model: "nn.Module") -> Device:
|
|
||||||
try:
|
|
||||||
param = next(model.parameters())
|
|
||||||
except StopIteration:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot determine model device, provide as `device` argument to `kernelize`."
|
|
||||||
)
|
|
||||||
|
|
||||||
return Device(type=param.device.type)
|
|
||||||
|
|
||||||
|
|
||||||
def _conditionally_replace_forward(
|
|
||||||
*,
|
|
||||||
module: "nn.Module",
|
|
||||||
layer: Type["nn.Module"],
|
|
||||||
mode: Mode,
|
|
||||||
use_fallback: bool,
|
|
||||||
):
|
|
||||||
module_class = type(module)
|
|
||||||
|
|
||||||
# Switch to fallback if the mode is not supported by the layer.
|
|
||||||
# Note that this is useful even after _validate_layer_has_mode because
|
|
||||||
# layers registered with the DEFAULT mode never get rejected by
|
|
||||||
# _validate_layer_has_mode. For such layers, we want to fall back in
|
|
||||||
# case the layer does not support the given mode.
|
|
||||||
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
|
|
||||||
layer, "can_torch_compile", False
|
|
||||||
)
|
|
||||||
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
|
|
||||||
|
|
||||||
if needs_fallback:
|
|
||||||
if use_fallback:
|
|
||||||
_replace_forward(module, module_class)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Available kernel does not support mode: {mode}")
|
|
||||||
else:
|
|
||||||
_replace_forward(module, layer)
|
|
||||||
|
|
||||||
|
|
||||||
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
|
|
||||||
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_layer_has_mode(
|
|
||||||
*,
|
|
||||||
layer_name: str,
|
|
||||||
module: Type["nn.Module"],
|
|
||||||
repo: LayerRepository,
|
|
||||||
repo_mode: Mode,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Check that a repository supports the mode that it was registered for.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
|
|
||||||
raise ValueError(
|
|
||||||
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
|
|
||||||
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
if Mode.TORCH_COMPILE in repo_mode and not getattr(
|
|
||||||
module, "can_torch_compile", False
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
|
|
||||||
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _get_layer_memoize(
|
|
||||||
repo: LayerRepository, module_class: Type["nn.Module"]
|
|
||||||
) -> Type["nn.Module"]:
|
|
||||||
layer = _CACHED_LAYER.get(repo, None)
|
|
||||||
if layer is not None:
|
|
||||||
return layer
|
|
||||||
|
|
||||||
layer = _get_kernel_layer(
|
|
||||||
repo_id=repo.repo_id,
|
|
||||||
layer_name=repo.layer_name,
|
|
||||||
revision=repo.revision,
|
|
||||||
)
|
|
||||||
_validate_layer(check_cls=module_class, cls=layer)
|
|
||||||
_CACHED_LAYER[repo] = layer
|
|
||||||
|
|
||||||
return layer
|
|
||||||
|
|||||||
@ -43,23 +43,14 @@ def build_variant() -> str:
|
|||||||
elif torch.version.hip is not None:
|
elif torch.version.hip is not None:
|
||||||
rocm_version = parse(torch.version.hip.split("-")[0])
|
rocm_version = parse(torch.version.hip.split("-")[0])
|
||||||
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
compute_framework = "metal"
|
|
||||||
else:
|
else:
|
||||||
raise AssertionError(
|
raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
|
||||||
"Torch was not compiled with CUDA, Metal, or ROCm enabled."
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_version = parse(torch.__version__)
|
torch_version = parse(torch.__version__)
|
||||||
|
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||||
cpu = platform.machine()
|
cpu = platform.machine()
|
||||||
os = platform.system().lower()
|
os = platform.system().lower()
|
||||||
|
|
||||||
if os == "darwin":
|
|
||||||
cpu = "aarch64" if cpu == "arm64" else cpu
|
|
||||||
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
|
|
||||||
|
|
||||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
|
||||||
|
|
||||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
|
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
|
||||||
|
|
||||||
|
|
||||||
@ -110,23 +101,6 @@ def install_kernel(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
return _load_kernel_from_path(repo_path, package_name, variant_locks)
|
|
||||||
except FileNotFoundError:
|
|
||||||
# Redo with more specific error message.
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_kernel_from_path(
|
|
||||||
repo_path: Path,
|
|
||||||
package_name: str,
|
|
||||||
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
|
||||||
) -> Tuple[str, Path]:
|
|
||||||
variant = build_variant()
|
|
||||||
universal_variant = universal_build_variant()
|
|
||||||
|
|
||||||
variant_path = repo_path / "build" / variant
|
variant_path = repo_path / "build" / variant
|
||||||
universal_variant_path = repo_path / "build" / universal_variant
|
universal_variant_path = repo_path / "build" / universal_variant
|
||||||
|
|
||||||
@ -145,7 +119,7 @@ def _load_kernel_from_path(
|
|||||||
|
|
||||||
if not os.path.exists(module_init_path):
|
if not os.path.exists(module_init_path):
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Kernel at path `{repo_path}` does not have build: {variant}"
|
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return package_name, variant_path
|
return package_name, variant_path
|
||||||
@ -183,24 +157,10 @@ def install_kernel_all_variants(
|
|||||||
|
|
||||||
|
|
||||||
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||||
"""
|
|
||||||
Download and import a kernel from the Hugging Face Hub.
|
|
||||||
|
|
||||||
The kernel is downloaded from the repository `repo_id` at
|
|
||||||
branch/commit/tag `revision`.
|
|
||||||
"""
|
|
||||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
|
||||||
"""
|
|
||||||
Import a kernel from a local kernel repository path.
|
|
||||||
"""
|
|
||||||
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
|
|
||||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
|
||||||
|
|
||||||
|
|
||||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||||
"""
|
"""
|
||||||
Check whether a kernel build exists for the current environment
|
Check whether a kernel build exists for the current environment
|
||||||
|
|||||||
@ -1,186 +0,0 @@
|
|||||||
import email.policy
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from email.message import Message
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
try:
|
|
||||||
KERNELS_VERSION = version("kernels")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
KERNELS_VERSION = "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Metadata:
|
|
||||||
name: str
|
|
||||||
version: str
|
|
||||||
cuda_version: Optional[str]
|
|
||||||
cxx_abi_version: Optional[str]
|
|
||||||
torch_version: Optional[str]
|
|
||||||
os: Optional[str]
|
|
||||||
platform: Optional[str]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_universal(self) -> bool:
|
|
||||||
return self.platform is None
|
|
||||||
|
|
||||||
|
|
||||||
def build_variant_to_wheel(
|
|
||||||
repo_id: str,
|
|
||||||
*,
|
|
||||||
version: str,
|
|
||||||
variant_path: Path,
|
|
||||||
wheel_dir: Path,
|
|
||||||
manylinux_version: str = "2.28",
|
|
||||||
python_version: str = "3.9",
|
|
||||||
) -> Path:
|
|
||||||
"""
|
|
||||||
Create a wheel file from the variant path.
|
|
||||||
"""
|
|
||||||
name = repo_id.split("/")[-1].replace("_", "-")
|
|
||||||
metadata = extract_metadata(name, version, variant_path)
|
|
||||||
return build_wheel(
|
|
||||||
metadata,
|
|
||||||
variant_path=variant_path,
|
|
||||||
wheel_dir=wheel_dir,
|
|
||||||
manylinux_version=manylinux_version,
|
|
||||||
python_version=python_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
|
|
||||||
"""
|
|
||||||
Extract metadata from the variant path.
|
|
||||||
"""
|
|
||||||
if variant_path.name == "torch-universal":
|
|
||||||
return Metadata(
|
|
||||||
name=name,
|
|
||||||
version=version,
|
|
||||||
cuda_version=None,
|
|
||||||
cxx_abi_version=None,
|
|
||||||
torch_version=None,
|
|
||||||
os=None,
|
|
||||||
platform=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not variant_path.name.startswith("torch"):
|
|
||||||
raise ValueError("Currently only conversion of Torch kernels is supported.")
|
|
||||||
|
|
||||||
variant_parts = variant_path.name.removeprefix("torch").split("-")
|
|
||||||
if len(variant_parts) != 5:
|
|
||||||
raise ValueError(f"Invalid variant name: {variant_path.name}")
|
|
||||||
|
|
||||||
torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}"
|
|
||||||
cpp_abi_version = variant_parts[1].removeprefix("cxx")
|
|
||||||
cuda_version = variant_parts[2].removeprefix("cu")
|
|
||||||
platform = variant_parts[3].replace("-", "_")
|
|
||||||
os = variant_parts[4]
|
|
||||||
|
|
||||||
return Metadata(
|
|
||||||
name=name,
|
|
||||||
version=version,
|
|
||||||
cuda_version=cuda_version,
|
|
||||||
cxx_abi_version=cpp_abi_version,
|
|
||||||
torch_version=torch_version,
|
|
||||||
os=os,
|
|
||||||
platform=platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_wheel(
|
|
||||||
metadata: Metadata,
|
|
||||||
*,
|
|
||||||
variant_path: Path,
|
|
||||||
wheel_dir: Path,
|
|
||||||
manylinux_version: str = "2.28",
|
|
||||||
python_version: str = "3.9",
|
|
||||||
) -> Path:
|
|
||||||
"""
|
|
||||||
Build the wheel file.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from wheel.wheelfile import WheelFile # type: ignore
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
|
|
||||||
)
|
|
||||||
|
|
||||||
name = metadata.name.replace("-", "_")
|
|
||||||
python_version_flat = python_version.replace(".", "")
|
|
||||||
|
|
||||||
if metadata.is_universal:
|
|
||||||
python_tag = f"py{python_version_flat}"
|
|
||||||
abi_tag = "none"
|
|
||||||
platform_tag = "any"
|
|
||||||
wheel_filename = (
|
|
||||||
f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
|
|
||||||
)
|
|
||||||
dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
|
|
||||||
root_is_purelib = "true"
|
|
||||||
requires_dist_torch = "torch"
|
|
||||||
else:
|
|
||||||
python_tag = f"cp{python_version_flat}"
|
|
||||||
abi_tag = "abi3"
|
|
||||||
|
|
||||||
if (
|
|
||||||
metadata.torch_version is None
|
|
||||||
or metadata.cuda_version is None
|
|
||||||
or metadata.cxx_abi_version is None
|
|
||||||
or metadata.os is None
|
|
||||||
or metadata.platform is None
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
|
|
||||||
)
|
|
||||||
|
|
||||||
local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
|
|
||||||
|
|
||||||
if metadata.os == "linux":
|
|
||||||
platform_tag = (
|
|
||||||
f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
|
|
||||||
|
|
||||||
wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
|
|
||||||
dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
|
|
||||||
root_is_purelib = "false"
|
|
||||||
requires_dist_torch = f"torch=={metadata.torch_version}.*"
|
|
||||||
|
|
||||||
wheel_path = wheel_dir / wheel_filename
|
|
||||||
|
|
||||||
wheel_msg = Message(email.policy.compat32)
|
|
||||||
wheel_msg.add_header("Wheel-Version", "1.0")
|
|
||||||
wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
|
|
||||||
wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
|
|
||||||
wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")
|
|
||||||
|
|
||||||
metadata_msg = Message(email.policy.compat32)
|
|
||||||
metadata_msg.add_header("Metadata-Version", "2.1")
|
|
||||||
metadata_msg.add_header("Name", name)
|
|
||||||
metadata_msg.add_header("Version", metadata.version)
|
|
||||||
metadata_msg.add_header("Summary", f"{name} kernel")
|
|
||||||
metadata_msg.add_header("Requires-Python", ">=3.9")
|
|
||||||
metadata_msg.add_header("Requires-Dist", requires_dist_torch)
|
|
||||||
|
|
||||||
source_pkg_dir = variant_path / name
|
|
||||||
|
|
||||||
with WheelFile(wheel_path, "w") as wheel_file:
|
|
||||||
for root, dirnames, filenames in os.walk(source_pkg_dir):
|
|
||||||
for filename in filenames:
|
|
||||||
if filename.endswith(".pyc"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
abs_filepath = os.path.join(root, filename)
|
|
||||||
entry_name = os.path.relpath(abs_filepath, variant_path)
|
|
||||||
wheel_file.write(abs_filepath, entry_name)
|
|
||||||
|
|
||||||
wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
|
|
||||||
wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
|
|
||||||
|
|
||||||
metadata_path = os.path.join(dist_info_dir_name, "METADATA")
|
|
||||||
wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
|
|
||||||
|
|
||||||
return wheel_path
|
|
||||||
@ -1,10 +0,0 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_setup(item):
|
|
||||||
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
|
|
||||||
pytest.skip("skipping Linux-only test on non-Linux platform")
|
|
||||||
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
|
|
||||||
pytest.skip("skipping macOS-only test on non-macOS platform")
|
|
||||||
@ -1,82 +1,54 @@
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"repo_id": "kernels-community/activation",
|
"repo_id": "kernels-community/activation",
|
||||||
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
|
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
|
||||||
"variants": {
|
"variants": {
|
||||||
"torch25-cxx11-cu118-x86_64-linux": {
|
"torch25-cxx11-cu118-x86_64-linux": {
|
||||||
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
|
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch25-cxx11-cu121-x86_64-linux": {
|
"torch25-cxx11-cu121-x86_64-linux": {
|
||||||
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
|
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch25-cxx11-cu124-x86_64-linux": {
|
"torch25-cxx11-cu124-x86_64-linux": {
|
||||||
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
|
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch25-cxx98-cu118-x86_64-linux": {
|
"torch25-cxx98-cu118-x86_64-linux": {
|
||||||
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
|
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch25-cxx98-cu121-x86_64-linux": {
|
"torch25-cxx98-cu121-x86_64-linux": {
|
||||||
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
|
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch25-cxx98-cu124-x86_64-linux": {
|
"torch25-cxx98-cu124-x86_64-linux": {
|
||||||
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
|
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch26-cxx11-cu118-x86_64-linux": {
|
"torch26-cxx11-cu118-x86_64-linux": {
|
||||||
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
|
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch26-cxx11-cu124-x86_64-linux": {
|
"torch26-cxx11-cu124-x86_64-linux": {
|
||||||
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
|
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
|
||||||
"hash_type": "git_lfs_concat"
|
|
||||||
},
|
|
||||||
"torch26-cxx11-cu126-aarch64-linux": {
|
|
||||||
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
|
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch26-cxx11-cu126-x86_64-linux": {
|
"torch26-cxx11-cu126-x86_64-linux": {
|
||||||
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
|
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch26-cxx98-cu118-x86_64-linux": {
|
"torch26-cxx98-cu118-x86_64-linux": {
|
||||||
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
|
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch26-cxx98-cu124-x86_64-linux": {
|
"torch26-cxx98-cu124-x86_64-linux": {
|
||||||
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
|
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
|
||||||
"hash_type": "git_lfs_concat"
|
|
||||||
},
|
|
||||||
"torch26-cxx98-cu126-aarch64-linux": {
|
|
||||||
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
|
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
},
|
},
|
||||||
"torch26-cxx98-cu126-x86_64-linux": {
|
"torch26-cxx98-cu126-x86_64-linux": {
|
||||||
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
|
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
|
||||||
"hash_type": "git_lfs_concat"
|
|
||||||
},
|
|
||||||
"torch27-cxx11-cu118-x86_64-linux": {
|
|
||||||
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
|
|
||||||
"hash_type": "git_lfs_concat"
|
|
||||||
},
|
|
||||||
"torch27-cxx11-cu126-aarch64-linux": {
|
|
||||||
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
|
|
||||||
"hash_type": "git_lfs_concat"
|
|
||||||
},
|
|
||||||
"torch27-cxx11-cu126-x86_64-linux": {
|
|
||||||
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
|
|
||||||
"hash_type": "git_lfs_concat"
|
|
||||||
},
|
|
||||||
"torch27-cxx11-cu128-aarch64-linux": {
|
|
||||||
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
|
|
||||||
"hash_type": "git_lfs_concat"
|
|
||||||
},
|
|
||||||
"torch27-cxx11-cu128-x86_64-linux": {
|
|
||||||
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
|
|
||||||
"hash_type": "git_lfs_concat"
|
"hash_type": "git_lfs_concat"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
|
from kernels import get_kernel, has_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -9,19 +9,6 @@ def kernel():
|
|||||||
return get_kernel("kernels-community/activation")
|
return get_kernel("kernels-community/activation")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def local_kernel():
|
|
||||||
package_name, path = install_kernel("kernels-community/activation", "main")
|
|
||||||
# Path is the build variant path (build/torch-<...>), so the grandparent
|
|
||||||
# is the kernel repository path.
|
|
||||||
return get_local_kernel(path.parent.parent, package_name)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def metal_kernel():
|
|
||||||
return get_kernel("kernels-test/relu-metal")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def universal_kernel():
|
def universal_kernel():
|
||||||
return get_kernel("kernels-community/triton-scaled-mm")
|
return get_kernel("kernels-community/triton-scaled-mm")
|
||||||
@ -34,7 +21,6 @@ def device():
|
|||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_gelu_fast(kernel, device):
|
def test_gelu_fast(kernel, device):
|
||||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
@ -50,31 +36,6 @@ def test_gelu_fast(kernel, device):
|
|||||||
assert torch.allclose(y, expected)
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_local_kernel(local_kernel, device):
|
|
||||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
|
||||||
y = torch.empty_like(x)
|
|
||||||
|
|
||||||
local_kernel.gelu_fast(y, x)
|
|
||||||
|
|
||||||
expected = torch.tensor(
|
|
||||||
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
|
||||||
device=device,
|
|
||||||
dtype=torch.float16,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert torch.allclose(y, expected)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.darwin_only
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
|
||||||
def test_relu_metal(metal_kernel, dtype):
|
|
||||||
x = torch.arange(-10, 10, dtype=dtype, device="mps")
|
|
||||||
y = metal_kernel.relu(x)
|
|
||||||
assert torch.allclose(y, torch.relu(x))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"kernel_exists",
|
"kernel_exists",
|
||||||
[
|
[
|
||||||
@ -91,7 +52,6 @@ def test_has_kernel(kernel_exists):
|
|||||||
assert has_kernel(repo_id, revision=revision) == kernel
|
assert has_kernel(repo_id, revision=revision) == kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_universal_kernel(universal_kernel):
|
def test_universal_kernel(universal_kernel):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||||
|
|||||||
@ -16,21 +16,18 @@ def device():
|
|||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_gelu_small(kernel, device, benchmark):
|
def test_gelu_small(kernel, device, benchmark):
|
||||||
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
benchmark(kernel.gelu_fast, y, x)
|
benchmark(kernel.gelu_fast, y, x)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_gelu_medium(kernel, device, benchmark):
|
def test_gelu_medium(kernel, device, benchmark):
|
||||||
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
benchmark(kernel.gelu_fast, y, x)
|
benchmark(kernel.gelu_fast, y, x)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_gelu_large(kernel, device, benchmark):
|
def test_gelu_large(kernel, device, benchmark):
|
||||||
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from kernels import load_kernel
|
from kernels import load_kernel
|
||||||
from kernels.cli import download_kernels
|
from kernels.cli import download_kernels
|
||||||
|
|
||||||
@ -19,7 +17,6 @@ def test_download_all_hash_validation():
|
|||||||
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
|
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_load_locked():
|
def test_load_locked():
|
||||||
project_dir = Path(__file__).parent / "kernel_locking"
|
project_dir = Path(__file__).parent / "kernel_locking"
|
||||||
# Also validates that hashing works correctly.
|
# Also validates that hashing works correctly.
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -8,8 +6,6 @@ from torch.nn import functional as F
|
|||||||
from kernels import (
|
from kernels import (
|
||||||
Device,
|
Device,
|
||||||
LayerRepository,
|
LayerRepository,
|
||||||
Mode,
|
|
||||||
kernelize,
|
|
||||||
register_kernel_mapping,
|
register_kernel_mapping,
|
||||||
use_kernel_forward_from_hub,
|
use_kernel_forward_from_hub,
|
||||||
)
|
)
|
||||||
@ -20,18 +16,14 @@ kernel_layer_mapping = {
|
|||||||
Device(type="cuda"): LayerRepository(
|
Device(type="cuda"): LayerRepository(
|
||||||
repo_id="kernels-community/activation",
|
repo_id="kernels-community/activation",
|
||||||
layer_name="SiluAndMul",
|
layer_name="SiluAndMul",
|
||||||
)
|
revision="layers",
|
||||||
},
|
|
||||||
"SiluAndMulNoCompile": {
|
|
||||||
"cuda": LayerRepository(
|
|
||||||
repo_id="kernels-test/op-without-fake-test",
|
|
||||||
layer_name="SiluAndMul",
|
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
"SiluAndMulStringDevice": {
|
"SiluAndMulStringDevice": {
|
||||||
"cuda": LayerRepository(
|
"cuda": LayerRepository(
|
||||||
repo_id="kernels-community/activation",
|
repo_id="kernels-community/activation",
|
||||||
layer_name="SiluAndMul",
|
layer_name="SiluAndMul",
|
||||||
|
revision="layers",
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -51,11 +43,6 @@ class SiluAndMul(nn.Module):
|
|||||||
return F.silu(input[..., :d]) * input[..., d:]
|
return F.silu(input[..., :d]) * input[..., d:]
|
||||||
|
|
||||||
|
|
||||||
@use_kernel_forward_from_hub("SiluAndMulNoCompile")
|
|
||||||
class SiluAndMulNoCompileKernel(SiluAndMul):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@use_kernel_forward_from_hub("SiluAndMul")
|
@use_kernel_forward_from_hub("SiluAndMul")
|
||||||
class SiluAndMulWithKernel(SiluAndMul):
|
class SiluAndMulWithKernel(SiluAndMul):
|
||||||
pass
|
pass
|
||||||
@ -66,18 +53,6 @@ class SiluAndMulStringDevice(SiluAndMul):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@use_kernel_forward_from_hub("Linear")
|
|
||||||
class TorchLinearWithCounter(nn.Linear):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
# Used to check that we called hub kernel.
|
|
||||||
self.n_calls = 0
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
||||||
self.n_calls += 1
|
|
||||||
return super().forward(input)
|
|
||||||
|
|
||||||
|
|
||||||
def test_arg_kinds():
|
def test_arg_kinds():
|
||||||
@use_kernel_forward_from_hub("ArgKind")
|
@use_kernel_forward_from_hub("ArgKind")
|
||||||
class ArgKind(nn.Module):
|
class ArgKind(nn.Module):
|
||||||
@ -96,7 +71,6 @@ def test_arg_kinds():
|
|||||||
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
|
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
||||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||||
def test_hub_forward(cls, device):
|
def test_hub_forward(cls, device):
|
||||||
@ -106,7 +80,7 @@ def test_hub_forward(cls, device):
|
|||||||
X = torch.randn((32, 64), device=device)
|
X = torch.randn((32, 64), device=device)
|
||||||
Y = silu_and_mul(X)
|
Y = silu_and_mul(X)
|
||||||
|
|
||||||
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
|
silu_and_mul_with_kernel = cls()
|
||||||
Y_kernel = silu_and_mul_with_kernel(X)
|
Y_kernel = silu_and_mul_with_kernel(X)
|
||||||
|
|
||||||
torch.testing.assert_close(Y_kernel, Y)
|
torch.testing.assert_close(Y_kernel, Y)
|
||||||
@ -124,70 +98,11 @@ def test_layer_fallback_works():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Check that we don't raise an exception for a non-existing kernel.
|
# Check that we don't raise an exception for a non-existing kernel.
|
||||||
silu_and_mul = SiluAndMulWithKernelFallback()
|
SiluAndMulWithKernelFallback()
|
||||||
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
|
||||||
@pytest.mark.parametrize("device", ["cuda"])
|
|
||||||
def test_torch_compile_layer_without_fallback(cls, device):
|
|
||||||
silu_and_mul = SiluAndMul()
|
|
||||||
|
|
||||||
X = torch.randn((32, 64), dtype=torch.float32, device=device)
|
|
||||||
Y = silu_and_mul(X)
|
|
||||||
|
|
||||||
silu_and_mul_with_kernel = cls()
|
|
||||||
silu_and_mul_with_kernel.eval()
|
|
||||||
|
|
||||||
ctx = (
|
|
||||||
pytest.raises(ValueError, match="does not support mode")
|
|
||||||
if cls is SiluAndMulNoCompileKernel
|
|
||||||
else nullcontext()
|
|
||||||
)
|
|
||||||
with ctx:
|
|
||||||
silu_and_mul_with_kernel = kernelize(
|
|
||||||
silu_and_mul_with_kernel,
|
|
||||||
device=device,
|
|
||||||
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
|
|
||||||
use_fallback=False,
|
|
||||||
)
|
|
||||||
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
|
|
||||||
|
|
||||||
Y_compiled = silu_and_mul_compiled(X)
|
|
||||||
|
|
||||||
torch.testing.assert_close(Y_compiled, Y)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
|
||||||
@pytest.mark.parametrize("device", ["cuda"])
|
|
||||||
def test_torch_compile_layer_with_fallback(cls, device):
|
|
||||||
silu_and_mul = SiluAndMul()
|
|
||||||
|
|
||||||
X = torch.randn((32, 64), dtype=torch.float32, device=device)
|
|
||||||
Y = silu_and_mul(X)
|
|
||||||
|
|
||||||
silu_and_mul_with_kernel = cls()
|
|
||||||
silu_and_mul_with_kernel.eval()
|
|
||||||
silu_and_mul_with_kernel = kernelize(
|
|
||||||
silu_and_mul_with_kernel,
|
|
||||||
device=device,
|
|
||||||
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
|
|
||||||
)
|
|
||||||
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
|
|
||||||
|
|
||||||
Y_compiled = silu_and_mul_compiled(X)
|
|
||||||
|
|
||||||
torch.testing.assert_close(Y_compiled, Y)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mapping_contexts():
|
def test_mapping_contexts():
|
||||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
|
||||||
"SiluAndMul",
|
|
||||||
"SiluAndMulStringDevice",
|
|
||||||
"SiluAndMulNoCompile",
|
|
||||||
}
|
|
||||||
|
|
||||||
extra_mapping1 = {
|
extra_mapping1 = {
|
||||||
"TestKernel": {
|
"TestKernel": {
|
||||||
@ -203,7 +118,6 @@ def test_mapping_contexts():
|
|||||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,26 +135,20 @@ def test_mapping_contexts():
|
|||||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
|
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||||
Mode.DEFAULT
|
|
||||||
].repo_id
|
|
||||||
== "kernels-community/non-existing"
|
== "kernels-community/non-existing"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
|
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||||
Mode.DEFAULT
|
|
||||||
].repo_id
|
|
||||||
== "kernels-community/activation"
|
== "kernels-community/activation"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -249,29 +157,23 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
|
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||||
Mode.DEFAULT
|
|
||||||
].repo_id
|
|
||||||
== "kernels-community/non-existing"
|
== "kernels-community/non-existing"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
|
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
|
||||||
Mode.DEFAULT
|
|
||||||
].repo_id
|
|
||||||
== "kernels-community/activation"
|
== "kernels-community/activation"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -303,174 +205,20 @@ def test_validate_kernel_layer():
|
|||||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_mode_for_mapping_rejected():
|
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
|
||||||
|
|
||||||
with use_kernel_mapping(
|
|
||||||
{
|
|
||||||
"Linear": {
|
|
||||||
"cuda": {
|
|
||||||
Mode.TRAINING: LayerRepository(
|
|
||||||
repo_id="kernels-test/backward-marker-test",
|
|
||||||
layer_name="LinearNoBackward",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
):
|
|
||||||
with pytest.raises(ValueError, match="does not support backward"):
|
|
||||||
kernelize(linear, mode=Mode.TRAINING)
|
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_modes():
|
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
|
||||||
|
|
||||||
# Case 1: layer without further specification, becomes the
|
|
||||||
# base layer.
|
|
||||||
with use_kernel_mapping(
|
|
||||||
{
|
|
||||||
"Linear": {
|
|
||||||
"cuda": LayerRepository(
|
|
||||||
repo_id="kernels-test/backward-marker-test",
|
|
||||||
layer_name="LinearBackward",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
):
|
|
||||||
kernelize(linear, mode=Mode.INFERENCE)
|
|
||||||
X = torch.randn(10, 32, device="cuda")
|
|
||||||
linear(X)
|
|
||||||
assert linear.n_calls == 0
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING)
|
|
||||||
linear(X)
|
|
||||||
assert linear.n_calls == 0
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
|
||||||
linear(X)
|
|
||||||
assert linear.n_calls == 0
|
|
||||||
|
|
||||||
# Case 2: register a kernel just for training. If no base kernel
|
|
||||||
# layer is registered, we fall back to the original layer.
|
|
||||||
with use_kernel_mapping(
|
|
||||||
{
|
|
||||||
"Linear": {
|
|
||||||
"cuda": {
|
|
||||||
Mode.TRAINING: LayerRepository(
|
|
||||||
repo_id="kernels-test/backward-marker-test",
|
|
||||||
layer_name="LinearBackward",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
):
|
|
||||||
kernelize(linear, mode=Mode.INFERENCE)
|
|
||||||
X = torch.randn(10, 32, device="cuda")
|
|
||||||
linear(X)
|
|
||||||
assert linear.n_calls == 1
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING)
|
|
||||||
linear(X)
|
|
||||||
# Training has a kernel, so fallback.
|
|
||||||
assert linear.n_calls == 1
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
|
||||||
linear(X)
|
|
||||||
# No kernel for training + torch.compile, so fallback.
|
|
||||||
assert linear.n_calls == 2
|
|
||||||
|
|
||||||
# Case 3: register a kernel just for training and one for fallback.
|
|
||||||
with use_kernel_mapping(
|
|
||||||
{
|
|
||||||
"Linear": {
|
|
||||||
"cuda": {
|
|
||||||
Mode.DEFAULT: LayerRepository(
|
|
||||||
repo_id="kernels-test/backward-marker-test",
|
|
||||||
layer_name="LinearBackward",
|
|
||||||
),
|
|
||||||
Mode.TRAINING: LayerRepository(
|
|
||||||
repo_id="kernels-test/backward-marker-test",
|
|
||||||
layer_name="LinearBackward",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
):
|
|
||||||
kernelize(linear, mode=Mode.INFERENCE)
|
|
||||||
X = torch.randn(10, 32, device="cuda")
|
|
||||||
linear(X)
|
|
||||||
# Uses the base kernel.
|
|
||||||
assert linear.n_calls == 2
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING)
|
|
||||||
linear(X)
|
|
||||||
# Uses the training kernel.
|
|
||||||
assert linear.n_calls == 2
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
|
||||||
linear(X)
|
|
||||||
# Uses the base kernel.
|
|
||||||
assert linear.n_calls == 2
|
|
||||||
|
|
||||||
# Case 4: register a kernel with two preferences.
|
|
||||||
with use_kernel_mapping(
|
|
||||||
{
|
|
||||||
"Linear": {
|
|
||||||
"cuda": {
|
|
||||||
Mode.TRAINING
|
|
||||||
| Mode.TORCH_COMPILE: LayerRepository(
|
|
||||||
repo_id="kernels-test/backward-marker-test",
|
|
||||||
layer_name="LinearBackward",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
):
|
|
||||||
kernelize(linear, mode=Mode.INFERENCE)
|
|
||||||
X = torch.randn(10, 32, device="cuda")
|
|
||||||
linear(X)
|
|
||||||
# No inference kernel, so fallback.
|
|
||||||
assert linear.n_calls == 3
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING)
|
|
||||||
linear(X)
|
|
||||||
# No training kernel, so fallback.
|
|
||||||
assert linear.n_calls == 4
|
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
|
||||||
linear(X)
|
|
||||||
# We do have a training + torch.compile kernel.
|
|
||||||
assert linear.n_calls == 4
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
|
||||||
def test_fallback_used_when_training():
|
def test_fallback_used_when_training():
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
@use_kernel_forward_from_hub("Linear")
|
||||||
|
class TorchLinear(nn.Linear):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
# Used to check that we called hub kernel.
|
||||||
|
self.n_calls = 0
|
||||||
|
|
||||||
# Case 1: kernel with explicit backward support should always
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
# use the kernel.
|
self.n_calls += 1
|
||||||
with use_kernel_mapping(
|
return super().forward(input)
|
||||||
{
|
|
||||||
"Linear": {
|
|
||||||
Device(type="cuda"): LayerRepository(
|
|
||||||
repo_id="kernels-test/backward-marker-test",
|
|
||||||
layer_name="LinearBackward",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
):
|
|
||||||
linear.train()
|
|
||||||
kernelize(linear, mode=Mode.INFERENCE)
|
|
||||||
X = torch.randn(10, 32, device="cuda")
|
|
||||||
linear(X)
|
|
||||||
assert linear.n_calls == 0
|
|
||||||
|
|
||||||
linear.eval()
|
linear = TorchLinear(32, 32).to("cuda")
|
||||||
linear(X)
|
|
||||||
assert linear.n_calls == 0
|
|
||||||
|
|
||||||
# Case 2: kernel with implicit backward support should always
|
|
||||||
# use the kernel.
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Linear": {
|
"Linear": {
|
||||||
@ -482,7 +230,6 @@ def test_fallback_used_when_training():
|
|||||||
}
|
}
|
||||||
):
|
):
|
||||||
linear.train()
|
linear.train()
|
||||||
kernelize(linear, mode=Mode.INFERENCE)
|
|
||||||
X = torch.randn(10, 32, device="cuda")
|
X = torch.randn(10, 32, device="cuda")
|
||||||
linear(X)
|
linear(X)
|
||||||
assert linear.n_calls == 0
|
assert linear.n_calls == 0
|
||||||
@ -491,18 +238,40 @@ def test_fallback_used_when_training():
|
|||||||
linear(X)
|
linear(X)
|
||||||
assert linear.n_calls == 0
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
def test_invalid_mode_rejected():
|
{
|
||||||
with pytest.raises(ValueError, match="mutually exclusive"):
|
"Linear": {
|
||||||
_ = Mode.INFERENCE | Mode.TRAINING
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
with pytest.raises(ValueError, match="cannot be combined with other modes"):
|
layer_name="LinearBackward",
|
||||||
_ = Mode.DEFAULT | Mode.TORCH_COMPILE
|
)
|
||||||
|
}
|
||||||
with pytest.raises(
|
}
|
||||||
ValueError, match="can only be used to register kernel mappings"
|
|
||||||
):
|
):
|
||||||
kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
|
linear.train()
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="mode must contain"):
|
linear.eval()
|
||||||
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
Device(type="cuda"): LayerRepository(
|
||||||
|
repo_id="kernels-test/backward-marker-test",
|
||||||
|
layer_name="LinearNoBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
):
|
||||||
|
linear.train()
|
||||||
|
X = torch.randn(10, 32, device="cuda")
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
|
linear.eval()
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user