Do not use kernels without backward when training (#68)

* Do not use kernels without backward when training

* Update repo for backwards marker test
Author: Daniël de Kok
Date: 2025-04-11 10:05:57 +02:00
Committed by: GitHub
Parent: 6fd2112e22
Commit: 1d14abcef0
3 changed files with 93 additions and 9 deletions
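The change boils down to one dispatch rule: while a module is in training mode, a Hub kernel layer is only used if it advertises a backward pass through its `has_backward` marker; otherwise the call drops back to the original `forward`. A minimal sketch of that rule in isolation (the helper name `dispatch_forward` is illustrative, not part of the library):

```python
import torch
import torch.nn as nn


def dispatch_forward(module: nn.Module, kernel_layer, fallback_forward, x: torch.Tensor):
    # Training requires gradients, so the kernel must support backward.
    needs_backward = module.training
    # A kernel layer that omits `has_backward` is assumed to support backward.
    if needs_backward and not getattr(kernel_layer, "has_backward", True):
        return fallback_forward(module, x)
    # Call the kernel layer's forward unbound, with the wrapped module as
    # `self`, mirroring `layer.forward(self, ...)` in the diff below.
    return kernel_layer.forward(module, x)
```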


@@ -119,10 +119,17 @@ requirements:
 - The `forward` method has a signature that is compatible with the
   `forward` method that it is extending.
 
+The only exception to the _no class variables rule_ is the addition of a
+`has_backward` class variable. This variable is used to indicate whether
+the layer has a backward pass implemented (assumed `True` when absent).
+
 This is an example of a pure layer:
 
 ```python
 class SiluAndMul(nn.Module):
+    # This layer does not implement backward.
+    has_backward: bool = False
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
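For contrast with the pure-kernel example above, a layer built from differentiable torch ops supports backward automatically and can rely on the default. A hedged sketch (this pure-torch variant is illustrative, not the kernel from the docs):

```python
import torch
import torch.nn as nn


class SiluAndMulTorch(nn.Module):
    # Composed of differentiable torch ops, so autograd supplies backward.
    # `has_backward` defaults to True and could be omitted entirely.
    has_backward: bool = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]
```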


@@ -4,7 +4,7 @@ import warnings
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, Dict, Union
+from typing import TYPE_CHECKING, Dict, Union
 
 from .utils import get_kernel
@@ -131,12 +131,13 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
     fallback_forward = cls.forward
 
-    cached_forward: Dict[LayerRepository, Callable] = {}
+    cached_layer: Dict[LayerRepository, nn.Module] = {}
 
     def forward(self, x, *args, **kwargs):
         if _DISABLE_KERNEL_MAPPING:
             return fallback_forward(self, x, *args, **kwargs)
 
+        needs_backward = self.training
+
         kernel = _KERNEL_MAPPING.get().get(layer_name)
         if kernel is None:
             warnings.warn(
@@ -162,9 +163,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
             return fallback_forward(self, x, *args, **kwargs)
 
         # Short-circuit if we already loaded the layer.
-        layer_forward = cached_forward.get(repo, None)
-        if layer_forward is not None:
-            return layer_forward(self, x, *args, **kwargs)
+        layer = cached_layer.get(repo, None)
+        if layer is not None:
+            if needs_backward and not getattr(layer, "has_backward", True):
+                return fallback_forward(self, x, *args, **kwargs)
+            return layer.forward(self, x, *args, **kwargs)
 
         layer = _get_kernel_layer(
             repo_id=repo.repo_id,
@@ -180,10 +183,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
         finally:
             cls.forward = orig_forward
 
-        layer_forward = layer.forward
-        cached_forward[repo] = layer_forward
+        cached_layer[repo] = layer
 
-        return layer_forward(self, x, *args, **kwargs)
+        if needs_backward and not getattr(layer, "has_backward", True):
+            return fallback_forward(self, x, *args, **kwargs)
+        return layer.forward(self, x, *args, **kwargs)
 
     cls.forward = forward
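End to end, the decorated path behaves roughly as follows. A hedged usage sketch, assuming the names used by the tests below (`use_kernel_forward_from_hub`, `use_kernel_mapping`, `Device`, `LayerRepository`) are importable from `kernels`, and with an illustrative repo id:

```python
import torch
import torch.nn as nn

from kernels import (  # assumed public import path
    Device,
    LayerRepository,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]


silu_and_mul = SiluAndMul().to("cuda")

with use_kernel_mapping(
    {
        "SiluAndMul": {
            Device(type="cuda"): LayerRepository(
                repo_id="kernels-community/activation",  # illustrative repo id
                layer_name="SiluAndMul",
            )
        }
    }
):
    x = torch.randn(8, 64, device="cuda")
    silu_and_mul.eval()   # inference: the hub kernel is used when available
    silu_and_mul(x)
    silu_and_mul.train()  # training: falls back unless the kernel has backward
    silu_and_mul(x)
```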
@@ -240,7 +244,8 @@ def _validate_layer(*, check_cls, cls):
     # ... or predefined member variables.
     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
     cls_members = {name for name, _ in inspect.getmembers(cls)}
-    if cls_members - torch_module_members != set():
+    difference = cls_members - torch_module_members
+    if difference != set() and difference != {"has_backward"}:
         raise TypeError("Layer must not contain additional members.")
 
     # Check whether the forward signatures are similar.
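The validator change whitelists exactly one extra member. A small self-contained check of the same set arithmetic (the `Marked` class is hypothetical):

```python
import inspect

import torch.nn as nn


class Marked(nn.Module):
    has_backward: bool = False  # the one extra member the validator now allows

    def forward(self, x):
        return x


torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(Marked)}

# `forward` overrides an existing nn.Module member, so it does not show up
# in the difference; only the marker does.
assert cls_members - torch_module_members == {"has_backward"}
```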


@@ -203,3 +203,75 @@ def test_validate_kernel_layer():
     with pytest.raises(TypeError, match="different kind of arguments"):
         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
+
+
+def test_fallback_used_when_training():
+    @use_kernel_forward_from_hub("Linear")
+    class TorchLinear(nn.Linear):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Used to check that we called the hub kernel.
+            self.n_calls = 0
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            self.n_calls += 1
+            return super().forward(input)
+
+    linear = TorchLinear(32, 32).to("cuda")
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearImplicitBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearNoBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 1
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 1
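The three layers in `kernels-test/backward-marker-test` are not shown in this diff; going by their names and the marker semantics, they presumably differ only in `has_backward`. A hypothetical sketch of the no-backward variant (reading `self.weight`/`self.bias` works because the layer's `forward` is invoked with the wrapped `nn.Linear` as `self`):

```python
import torch
import torch.nn as nn


class LinearNoBackward(nn.Module):
    # Marks the kernel as forward-only: during training, the extension
    # falls back to the wrapped module's original forward.
    has_backward: bool = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `self` is the decorated nn.Linear, so its parameters are available.
        return nn.functional.linear(x, self.weight, self.bias)
```

`LinearBackward` would set `has_backward = True`, and `LinearImplicitBackward` would omit the marker and rely on the `True` default; the test above asserts exactly those three cases.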