Mirror of https://github.com/huggingface/kernels.git, synced 2025-10-21 21:38:52 +08:00

Compare commits: v0.10.4 ... kernelize- (1 commit)

Commit b37e3d468a
Documentation changes:

````diff
@@ -49,13 +49,16 @@ A model will not use Hub kernels by default, even if it contains extensible
 layers. To enable the use of Hub kernels in the model, it needs to be
 'kernelized' using the `kernelize` function. This function traverses the
 model graph and replaces the `forward` methods of extensible layers for which
-Hub kernels are registered. Kernelize can be used as follows:
+Hub kernels are registered. `kernelize` can be used as follows:
 
 ```python
 model = MyModel(...)
 model = kernelize(model, mode=Mode.INFERENCE)
 ```
 
+The `kernelize` function modifies the model in-place, the model
+itself is returned as a convenience.
+
 The `mode` specifies that the model will be used in inference. Similarly,
 you can ask `kernelize` to prepare the model for training:
 
````
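For context, a layer opts in to Hub kernels through the `use_kernel_forward_from_hub` decorator, which registers the class so that `kernelize` can later swap its `forward`. A minimal sketch, assuming the `kernels` package API referenced in these docs; the `SiluAndMul` implementation below is purely illustrative:

```python
import torch
import torch.nn as nn
from kernels import Mode, kernelize, use_kernel_forward_from_hub


# Illustrative layer: the decorator marks it as extensible, so
# `kernelize` may replace its `forward` with a Hub kernel's.
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]


model = nn.Sequential(nn.Linear(32, 64), SiluAndMul())
# Traverses the model and swaps `forward` on extensible layers
# for which a Hub kernel is registered.
model = kernelize(model, mode=Mode.INFERENCE)
```

The next hunk moves the in-place note and documents the new default for `mode`: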
````diff
@@ -64,8 +67,11 @@ model = MyModel(...)
 model = kernelize(model, mode=Mode.TRAINING)
 ```
 
-**Note:** the `kernelize` function modifies the model in-place, the model
-itself is returned as a convenience.
+When the `mode` argument is not specified, the
+`Mode.TRAINING | Mode.TORCH_COMPILE` mode is used as the default. This mode
+aligns most closely with pure PyTorch layers (which generally support backward
+passes and `torch.compile`). However, this mode can also lead to fewer
+kernels being used, since not all kernels support training or `torch.compile`.
 
 ### Kernel device
 
````
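Concretely, after this change the two calls below are interchangeable (a small sketch):

```python
import torch.nn as nn
from kernels import Mode, kernelize

model = nn.Sequential(nn.Linear(32, 32))

# With the new default for `mode`, these calls behave identically.
kernelize(model)
kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

The commit also adds a section describing how modes fall back: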
```diff
@@ -196,6 +202,31 @@ In this case, modes other than `Mode.INFERENCE` and
 `Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
 `kernels-community/activation`.
 
+### Mode fallback behavior
+
+As described above, if there is no exact match for the mode given to
+`kernelize`, it will try to use the kernel registered for `Mode.DEFAULT`.
+If the `Mode.DEFAULT` kernel does not support the `kernelize` mode, the
+original layer's `forward` method will be used instead.
+
+As an example, suppose that two kernels were registered for a layer:
+
+1. Kernel `A` is registered for `Mode.DEFAULT`. This kernel supports training
+   (backward), but not `torch.compile`.
+2. Kernel `B` is registered for `Mode.INFERENCE | Mode.TORCH_COMPILE` and
+   supports `torch.compile`.
+
+`kernelize` modes will then behave as follows:
+
+- `Mode.INFERENCE | Mode.TORCH_COMPILE` uses kernel `B`: exact match.
+- `Mode.INFERENCE` uses kernel `A`: no exact match, so fall back to
+  `Mode.DEFAULT`.
+- `Mode.TRAINING` uses kernel `A`: no exact match, so fall back to
+  `Mode.DEFAULT`, which supports training.
+- `Mode.TRAINING | Mode.TORCH_COMPILE` uses the original layer's
+  `forward`: no exact match, and falling back to `Mode.DEFAULT` is not
+  possible because kernel `A` does not support `torch.compile`.
+
 ### Registering kernels for specific CUDA capabilities
 
 Some kernels only work with newer CUDA architectures. For instance, some
```
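The A/B example above can be expressed as a kernel mapping. A sketch, assuming the mapping shape used elsewhere in these docs (layer name → device type → mode → `LayerRepository`); `your-org/silu-compile` is a hypothetical repository name, and the model is assumed to contain an extensible `SiluAndMul` layer as in the earlier sketch:

```python
import torch.nn as nn
from kernels import LayerRepository, Mode, kernelize, use_kernel_mapping

# Placeholder model so the snippet stands alone; in practice it would
# contain an extensible `SiluAndMul` layer.
model = nn.Sequential(nn.Linear(32, 64))

kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            # Kernel A: the general registration, used when no more
            # specific mode matches.
            Mode.DEFAULT: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            # Kernel B: registered for inference + torch.compile only.
            # `your-org/silu-compile` is a hypothetical repository.
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="your-org/silu-compile",
                layer_name="SiluAndMul",
            ),
        }
    }
}

with use_kernel_mapping(kernel_layer_mapping):
    # Exact match for kernel B; Mode.INFERENCE alone would fall
    # back to kernel A via Mode.DEFAULT.
    model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```

Registering the general kernel under `Mode.DEFAULT` keeps the mapping small: only modes with genuinely different kernels need their own entry.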
Implementation: the `mode` argument of `kernelize` gains a default:

```diff
@@ -325,7 +325,7 @@ def _select_repository(
 def kernelize(
     model: "nn.Module",
     *,
-    mode: Mode,
+    mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
     device: Optional[Union[str, "torch.device"]] = None,
     use_fallback: bool = True,
 ):
```
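With this change `mode` can be omitted, while `device` and `use_fallback` keep their existing behavior. A small usage sketch, assuming `use_fallback=False` raises rather than silently reverting to the original `forward`, consistent with the fallback behavior described above:

```python
import torch.nn as nn
from kernels import kernelize

model = nn.Sequential(nn.Linear(32, 32))

# `mode` now defaults to Mode.TRAINING | Mode.TORCH_COMPILE.
# `device` pins kernel lookup to CUDA; `use_fallback=False` is assumed
# to raise when a registered kernel cannot be used.
model = kernelize(model, device="cuda", use_fallback=False)
```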
Tests in `test_kernel_modes` are extended to cover the default mode:

```diff
@@ -400,6 +400,11 @@ def test_kernel_modes():
         linear(X)
         assert linear.n_calls == 0
 
+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        kernelize(linear)
+        linear(X)
+        assert linear.n_calls == 0
+
     # Case 2: register a kernel just for training. If no base kernel
     # layer is registered, we fall back to the original layer.
     with use_kernel_mapping(
```
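The `n_calls` assertions hinge on a counting layer: the counter increments only in the original `forward`, so a count that stays flat means a Hub kernel handled the call, while an increment means the fallback ran. A hypothetical reconstruction of such a helper (the real test fixture may differ):

```python
import torch
import torch.nn as nn
from kernels import use_kernel_forward_from_hub


# Hypothetical test helper: a linear layer that counts fallback calls.
@use_kernel_forward_from_hub("Linear")
class TorchLinearWithCounter(nn.Linear):
    n_calls: int = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the original forward bumps the counter; a kernelized
        # layer bypasses this method entirely.
        self.n_calls += 1
        return super().forward(x)
```

The remaining hunks adjust the expected counts now that the extra default-mode `kernelize` calls run in each case: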
```diff
@@ -429,6 +434,12 @@ def test_kernel_modes():
         # No kernel for training + torch.compile, so fallback.
         assert linear.n_calls == 2
 
+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        kernelize(linear)
+        linear(X)
+        # No kernel for training + torch.compile, so fallback.
+        assert linear.n_calls == 3
+
     # Case 3: register a kernel just for training and one for fallback.
     with use_kernel_mapping(
         {
```
```diff
@@ -450,17 +461,23 @@ def test_kernel_modes():
         X = torch.randn(10, 32, device="cuda")
         linear(X)
         # Uses the base kernel.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 3
 
         kernelize(linear, mode=Mode.TRAINING)
         linear(X)
         # Uses the training kernel.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 3
 
         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
         linear(X)
         # Uses the base kernel.
-        assert linear.n_calls == 2
+        assert linear.n_calls == 3
 
+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        kernelize(linear)
+        linear(X)
+        # Uses the base kernel.
+        assert linear.n_calls == 3
+
     # Case 4: register a kernel with two preferences.
     with use_kernel_mapping(
```
```diff
@@ -480,17 +497,22 @@ def test_kernel_modes():
         X = torch.randn(10, 32, device="cuda")
         linear(X)
         # No inference kernel, so fallback.
-        assert linear.n_calls == 3
+        assert linear.n_calls == 4
 
         kernelize(linear, mode=Mode.TRAINING)
         linear(X)
         # No training kernel, so fallback.
-        assert linear.n_calls == 4
+        assert linear.n_calls == 5
 
         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
         linear(X)
         # We do have a training + torch.compile kernel.
-        assert linear.n_calls == 4
+        assert linear.n_calls == 5
 
+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        kernelize(linear)
+        linear(X)
+        assert linear.n_calls == 5
+
 
 @pytest.mark.linux_only
```