Compare commits

...

1 Commit

Author SHA1 Message Date
b37e3d468a Set kernelize default mode to Mode.TRAINING | Mode.TORCH_COMPILE
Also update docs and tests.
2025-07-15 08:44:13 +00:00
3 changed files with 63 additions and 10 deletions

View File

@@ -49,13 +49,16 @@ A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:
Hub kernels are registered. `kernelize` can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `kernelize` function modifies the model in-place, the model
itself is returned as a convenience.
The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:
@@ -64,8 +67,11 @@ model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
**Note:** the `kernelize` function modifies the model in-place; the model
itself is returned as a convenience.
When the `mode` argument is not specified, the
`Mode.TRAINING | Mode.TORCH_COMPILE` mode is used as the default. This mode
aligns most closely with pure PyTorch layers (which generally support backward
passes and `torch.compile`). However, this mode can also lead to fewer
kernels being used, since not all kernels support training or `torch.compile`.
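For example, the sketch below relies on the default mode; it is equivalent to
passing `Mode.TRAINING | Mode.TORCH_COMPILE` explicitly. This is a minimal
sketch following the `MyModel` example above, assuming `kernelize` and `Mode`
are imported from the `kernels` package as in the other examples:
```python
from kernels import Mode, kernelize

model = MyModel(...)

# No `mode` argument: Mode.TRAINING | Mode.TORCH_COMPILE is used.
model = kernelize(model)

# Equivalent explicit call:
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```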
### Kernel device
@@ -196,6 +202,31 @@ In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.
### Mode fallback behavior
As described above, if there is no exact match for the mode given to
`kernelize`, it will try to use the kernel registered for `Mode.DEFAULT`.
If the `Mode.DEFAULT` kernel does not support the `kernelize` mode, the
original layer's `forward` method will be used instead.
As an example, suppose that two kernels were registered for a layer:
1. Kernel `A` is registered for `Mode.DEFAULT`. This kernel supports training
(backward), but not `torch.compile`.
2. Kernel `B` is registered for `Mode.INFERENCE | Mode.TORCH_COMPILE` and
supports `torch.compile`.
`kernelize` modes will then behave as follows (see the sketch after this list):
- `Mode.INFERENCE | Mode.TORCH_COMPILE` uses kernel `B`: exact match.
- `Mode.INFERENCE` uses kernel `A`: no exact match, so it falls back to
`Mode.DEFAULT`.
- `Mode.TRAINING` uses kernel `A`: no exact match, so it falls back to
`Mode.DEFAULT`, which supports training.
- `Mode.TRAINING | Mode.TORCH_COMPILE` uses the original layer's
`forward`: no exact match, and falling back to `Mode.DEFAULT` is not possible
because kernel `A` does not support `torch.compile`.
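The sketch below shows how such a registration could look. It assumes the
mapping structure used in the registration examples elsewhere in these docs
(layer name → device → mode → `LayerRepository`); the repository ids for
kernels `A` and `B` are hypothetical placeholders.
```python
from kernels import LayerRepository, Mode, kernelize, use_kernel_mapping

# `MyModel` as in the earlier examples; assumed to contain a
# kernel-extensible SiluAndMul layer.
model = MyModel(...)

mapping = {
    "SiluAndMul": {
        "cuda": {
            # Kernel A: supports backward passes, but not torch.compile.
            Mode.DEFAULT: LayerRepository(
                repo_id="example-org/kernel-a",  # hypothetical repository
                layer_name="SiluAndMul",
            ),
            # Kernel B: supports torch.compile.
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="example-org/kernel-b",  # hypothetical repository
                layer_name="SiluAndMul",
            ),
        }
    }
}

with use_kernel_mapping(mapping):
    model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)  # kernel B: exact match
    model = kernelize(model, mode=Mode.INFERENCE)  # kernel A via Mode.DEFAULT
    model = kernelize(model, mode=Mode.TRAINING)   # kernel A via Mode.DEFAULT
    model = kernelize(model)  # default mode: no kernel supports it, original forward is used
```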
### Registering kernels for specific CUDA capabilities
Some kernels only work with newer CUDA architectures. For instance, some

View File

@@ -325,7 +325,7 @@ def _select_repository(
def kernelize(
model: "nn.Module",
*,
mode: Mode,
mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
device: Optional[Union[str, "torch.device"]] = None,
use_fallback: bool = True,
):

View File

@@ -400,6 +400,11 @@ def test_kernel_modes():
linear(X)
assert linear.n_calls == 0
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
@@ -429,6 +434,12 @@ def test_kernel_modes():
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 3
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
@@ -450,17 +461,23 @@ def test_kernel_modes():
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses the training kernel.
assert linear.n_calls == 2
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
assert linear.n_calls == 3
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 3
# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
@@ -480,17 +497,22 @@ def test_kernel_modes():
X = torch.randn(10, 32, device="cuda")
linear(X)
# No inference kernel, so fallback.
assert linear.n_calls == 3
assert linear.n_calls == 4
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback.
assert linear.n_calls == 4
assert linear.n_calls == 5
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# We do have a training + torch.compile kernel.
assert linear.n_calls == 4
assert linear.n_calls == 5
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 5
@pytest.mark.linux_only