Do not use kernels without backward when training (#68)

* Do not use kernels without backward when training

* Update repo for backwards marker test
Author: Daniël de Kok
Date: 2025-04-11 10:05:57 +02:00
Committed by: GitHub
Parent: 6fd2112e22
Commit: 1d14abcef0

3 changed files with 93 additions and 9 deletions

@@ -119,10 +119,17 @@ requirements:
 - The `forward` method has a signature that is compatible with the
   `forward` method that it is extending.
 
+The only exception to the _no class variables rule_ is the addition of a
+`has_backward` class variable. This variable indicates whether the layer
+implements a backward pass (assumed `True` when the variable is absent).
+
 This is an example of a pure layer:
 
 ```python
 class SiluAndMul(nn.Module):
+    # This layer does not implement backward.
+    has_backward: bool = False
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
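
For contrast with the `SiluAndMul` example above, here is a minimal sketch of a layer that does support backward; the `SiluAndMulBackward` name and its body are invented for illustration and are not part of this diff. Because an absent marker is treated as `True`, declaring `has_backward` is optional in this case:

```python
import torch
from torch import nn


class SiluAndMulBackward(nn.Module):
    # All operations below are differentiable through PyTorch autograd,
    # so backward is supported. The marker could also be omitted entirely,
    # since `has_backward` is assumed True when absent.
    has_backward: bool = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]
```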

@@ -4,7 +4,7 @@ import warnings
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, Dict, Union
+from typing import TYPE_CHECKING, Dict, Union
 
 from .utils import get_kernel
@@ -131,12 +131,13 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
     fallback_forward = cls.forward
 
-    cached_forward: Dict[LayerRepository, Callable] = {}
+    cached_layer: Dict[LayerRepository, nn.Module] = {}
 
     def forward(self, x, *args, **kwargs):
         if _DISABLE_KERNEL_MAPPING:
             return fallback_forward(self, x, *args, **kwargs)
 
+        needs_backward = self.training
         kernel = _KERNEL_MAPPING.get().get(layer_name)
         if kernel is None:
             warnings.warn(
@@ -162,9 +163,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
             return fallback_forward(self, x, *args, **kwargs)
 
         # Short-circuit if we already loaded the layer.
-        layer_forward = cached_forward.get(repo, None)
-        if layer_forward is not None:
-            return layer_forward(self, x, *args, **kwargs)
+        layer = cached_layer.get(repo, None)
+        if layer is not None:
+            if needs_backward and not getattr(layer, "has_backward", True):
+                return fallback_forward(self, x, *args, **kwargs)
+            return layer.forward(self, x, *args, **kwargs)
 
         layer = _get_kernel_layer(
             repo_id=repo.repo_id,
@@ -180,10 +183,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
         finally:
             cls.forward = orig_forward
 
-        layer_forward = layer.forward
-        cached_forward[repo] = layer_forward
+        cached_layer[repo] = layer
 
-        return layer_forward(self, x, *args, **kwargs)
+        if needs_backward and not getattr(layer, "has_backward", True):
+            return fallback_forward(self, x, *args, **kwargs)
+        return layer.forward(self, x, *args, **kwargs)
 
     cls.forward = forward
@@ -240,7 +244,8 @@ def _validate_layer(*, check_cls, cls):
     # ... or predefined member variables.
     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    if cls_members - torch_module_members != set():
+    difference = cls_members - torch_module_members
+    if difference != set() and difference != {"has_backward"}:
         raise TypeError("Layer must not contain additional members.")
 
     # Check whether the forward signatures are similar.
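
Taken together, the changes above amount to one extra decision per forward call: if the wrapping module is in training mode and the mapped kernel layer is marked `has_backward = False`, the original fallback `forward` is used. The sketch below isolates that decision with an invented helper name (`select_forward`) and without the caching and mapping machinery of the real code; as in the diff, the selected `forward` is then called with the original module as `self`:

```python
from typing import Callable, Type

from torch import nn


def select_forward(
    module: nn.Module,
    kernel_layer: Type[nn.Module],
    fallback_forward: Callable,
) -> Callable:
    """Illustrative only: choose which forward to run for this call."""
    # Training implies that a backward pass will be needed.
    needs_backward = module.training
    # An absent `has_backward` marker counts as True.
    if needs_backward and not getattr(kernel_layer, "has_backward", True):
        return fallback_forward
    return kernel_layer.forward
```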

@@ -203,3 +203,75 @@ def test_validate_kernel_layer():
     with pytest.raises(TypeError, match="different kind of arguments"):
         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
+
+
+def test_fallback_used_when_training():
+    @use_kernel_forward_from_hub("Linear")
+    class TorchLinear(nn.Linear):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Used to check that we called hub kernel.
+            self.n_calls = 0
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            self.n_calls += 1
+            return super().forward(input)
+
+    linear = TorchLinear(32, 32).to("cuda")
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearImplicitBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearNoBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 1
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 1
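
The three layers referenced from `kernels-test/backward-marker-test` are not part of this diff; judging by their names and the marker semantics above, they presumably look roughly like the sketch below, whose bodies are assumptions for illustration only. At call time `self` is the wrapped module (the `TorchLinear` in the test), so `self.weight` and `self.bias` refer to its parameters:

```python
import torch
from torch import nn


class LinearBackward(nn.Module):
    # Explicitly marked as supporting backward; used even in training mode.
    has_backward: bool = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight, self.bias)


class LinearImplicitBackward(nn.Module):
    # No marker: `has_backward` is assumed True, so this behaves like
    # LinearBackward for the training check.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight, self.bias)


class LinearNoBackward(nn.Module):
    # Marked forward-only: in training mode the original forward is used
    # instead, which is what increments `n_calls` in the test.
    has_backward: bool = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight, self.bias)
```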