diff --git a/docs/kernel-requirements.md b/docs/kernel-requirements.md
index 4efdcdc..a02550b 100644
--- a/docs/kernel-requirements.md
+++ b/docs/kernel-requirements.md
@@ -119,10 +119,17 @@ requirements:
 - The `forward` method has a signature that is compatible with the
   `forward` method that it is extending.

+The only exception to the _no class variables rule_ is the addition of a
+`has_backward` class variable. This variable indicates whether the layer
+implements a backward pass (it is assumed to be `True` when absent).
+
 This is an example of a pure layer:

 ```python
 class SiluAndMul(nn.Module):
+    # This layer does not implement backward.
+    has_backward: bool = False
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
diff --git a/src/kernels/layer.py b/src/kernels/layer.py
index 77d157c..51901d6 100644
--- a/src/kernels/layer.py
+++ b/src/kernels/layer.py
@@ -4,7 +4,7 @@ import warnings
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, Dict, Union
+from typing import TYPE_CHECKING, Dict, Union

 from .utils import get_kernel

@@ -131,12 +131,13 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool

     fallback_forward = cls.forward

-    cached_forward: Dict[LayerRepository, Callable] = {}
+    cached_layer: Dict[LayerRepository, nn.Module] = {}

     def forward(self, x, *args, **kwargs):
         if _DISABLE_KERNEL_MAPPING:
             return fallback_forward(self, x, *args, **kwargs)

+        needs_backward = self.training
         kernel = _KERNEL_MAPPING.get().get(layer_name)
         if kernel is None:
             warnings.warn(
@@ -162,9 +163,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
             return fallback_forward(self, x, *args, **kwargs)

         # Short-circuit if we already loaded the layer.
-        layer_forward = cached_forward.get(repo, None)
-        if layer_forward is not None:
-            return layer_forward(self, x, *args, **kwargs)
+        layer = cached_layer.get(repo, None)
+        if layer is not None:
+            if needs_backward and not getattr(layer, "has_backward", True):
+                return fallback_forward(self, x, *args, **kwargs)
+            return layer.forward(self, x, *args, **kwargs)

         layer = _get_kernel_layer(
             repo_id=repo.repo_id,
@@ -180,10 +183,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
         finally:
             cls.forward = orig_forward

-        layer_forward = layer.forward
-        cached_forward[repo] = layer_forward
+        cached_layer[repo] = layer

-        return layer_forward(self, x, *args, **kwargs)
+        if needs_backward and not getattr(layer, "has_backward", True):
+            return fallback_forward(self, x, *args, **kwargs)
+        return layer.forward(self, x, *args, **kwargs)

     cls.forward = forward

@@ -240,7 +244,8 @@ def _validate_layer(*, check_cls, cls):
     # ... or predefined member variables.
     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
     cls_members = {name for name, _ in inspect.getmembers(cls)}
-    if cls_members - torch_module_members != set():
+    difference = cls_members - torch_module_members
+    if difference != set() and difference != {"has_backward"}:
         raise TypeError("Layer must not contain additional members.")

     # Check whether the forward signatures are similar.
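The dispatch rule added in `layer.py` is small enough to reproduce in isolation. The following is a minimal sketch, not part of this patch: the `run_layer` helper and the two example GELU layers are hypothetical, and it simplifies away the fact that the real code calls the hub layer's `forward` with the wrapped module as `self`. It only shows how `self.training` together with `getattr(layer, "has_backward", True)` chooses between the hub layer and the fallback, with an absent marker treated as backward support.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeluInferenceOnly(nn.Module):
    # Hypothetical hub layer that opts out of training-time use.
    has_backward: bool = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(x)


class GeluWithBackward(nn.Module):
    # No `has_backward` marker, so it is assumed to support backward.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(x)


def run_layer(module: nn.Module, layer: nn.Module, fallback, x: torch.Tensor):
    # Same decision as the patched forward: fall back only when the wrapped
    # module is in training mode and the hub layer opts out of backward.
    needs_backward = module.training
    if needs_backward and not getattr(layer, "has_backward", True):
        return fallback(x)
    return layer.forward(x)


module = nn.GELU().train()
x = torch.randn(4, 8)

# Training + no backward support: the fallback (here plain F.gelu) is used.
run_layer(module, GeluInferenceOnly(), F.gelu, x)

# Training + backward support (marker absent): the hub layer is used.
run_layer(module, GeluWithBackward(), F.gelu, x)
```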
diff --git a/tests/test_layer.py b/tests/test_layer.py
index 2eeffbf..cc7768c 100644
--- a/tests/test_layer.py
+++ b/tests/test_layer.py
@@ -203,3 +203,75 @@ def test_validate_kernel_layer():

     with pytest.raises(TypeError, match="different kind of arguments"):
         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
+
+
+def test_fallback_used_when_training():
+    @use_kernel_forward_from_hub("Linear")
+    class TorchLinear(nn.Linear):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Used to check that we called hub kernel.
+            self.n_calls = 0
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            self.n_calls += 1
+            return super().forward(input)
+
+    linear = TorchLinear(32, 32).to("cuda")
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearImplicitBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearNoBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 1
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 1
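For context on the three layer names the test maps to: a backward-capable hub layer either omits `has_backward` (the "implicit" case) or sets it to `True`, while an inference-only layer sets it to `False`. The sketch below only illustrates that convention; it is not the actual contents of `kernels-test/backward-marker-test`, and the class bodies are hypothetical. Each layer reads `weight` and `bias` from the `nn.Linear` it extends, since kernel layers may not define members of their own.

```python
import torch
import torch.nn as nn


class LinearNoBackward(nn.Module):
    # Inference-only: the dispatcher falls back to the original forward
    # whenever the wrapped module is in training mode.
    has_backward: bool = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `self` is the extended nn.Linear, so weight/bias come from it.
        return nn.functional.linear(x, self.weight, self.bias)


class LinearImplicitBackward(nn.Module):
    # `has_backward` is absent and therefore assumed to be `True`.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight, self.bias)


class LinearBackward(nn.Module):
    # Explicitly advertises backward support.
    has_backward: bool = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight, self.bias)
```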