Mirror of https://github.com/huggingface/kernels.git, synced 2025-10-20 21:10:02 +08:00
Do not use kernels without backward when training (#68)
* Do not use kernels without backward when training
* Update repo for backwards marker test
@@ -119,10 +119,17 @@ requirements:
 - The `forward` method has a signature that is compatible with the
   `forward` method that it is extending.
 
+The only exception to the _no class variables rule_ is the addition of a
+`has_backward` class variable. This variable is used to indicate whether
+the layer has a backward pass implemented (`True` when absent).
+
 This is an example of a pure layer:
 
 ```python
 class SiluAndMul(nn.Module):
+    # This layer does not implement backward.
+    has_backward: bool = False
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
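For contrast, here is a minimal sketch (not part of this diff) of a layer on the other side of the marker: it can be used while training, so it either omits `has_backward` (absent counts as `True`) or sets it explicitly. The class name is made up, and it relies on PyTorch autograd rather than a hand-written backward kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiluAndMulWithBackward(nn.Module):
    # Hypothetical counterpart: gradients are available (via autograd here,
    # or via an explicit backward kernel in a real hub layer), so the layer
    # may be used while the wrapping module is in training mode.
    has_backward: bool = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]
```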
@@ -4,7 +4,7 @@ import warnings
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, Dict, Union
+from typing import TYPE_CHECKING, Dict, Union
 
 from .utils import get_kernel
 
@@ -131,12 +131,13 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
 
     fallback_forward = cls.forward
 
-    cached_forward: Dict[LayerRepository, Callable] = {}
+    cached_layer: Dict[LayerRepository, nn.Module] = {}
 
     def forward(self, x, *args, **kwargs):
         if _DISABLE_KERNEL_MAPPING:
             return fallback_forward(self, x, *args, **kwargs)
 
+        needs_backward = self.training
         kernel = _KERNEL_MAPPING.get().get(layer_name)
         if kernel is None:
             warnings.warn(
@@ -162,9 +163,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
             return fallback_forward(self, x, *args, **kwargs)
 
         # Short-circuit if we already loaded the layer.
-        layer_forward = cached_forward.get(repo, None)
-        if layer_forward is not None:
-            return layer_forward(self, x, *args, **kwargs)
+        layer = cached_layer.get(repo, None)
+        if layer is not None:
+            if needs_backward and not getattr(layer, "has_backward", True):
+                return fallback_forward(self, x, *args, **kwargs)
+            return layer.forward(self, x, *args, **kwargs)
 
         layer = _get_kernel_layer(
             repo_id=repo.repo_id,
@@ -180,10 +183,11 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
         finally:
             cls.forward = orig_forward
 
-        layer_forward = layer.forward
-        cached_forward[repo] = layer_forward
+        cached_layer[repo] = layer
 
-        return layer_forward(self, x, *args, **kwargs)
+        if needs_backward and not getattr(layer, "has_backward", True):
+            return fallback_forward(self, x, *args, **kwargs)
+        return layer.forward(self, x, *args, **kwargs)
 
     cls.forward = forward
 
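Distilled, the wrapper's new decision is: while the consuming module is in training mode and the hub layer advertises `has_backward = False`, call the original (fallback) forward; otherwise dispatch to the hub layer. A condensed sketch of just that branch, with simplified names (`hub_layer`, `fallback`) that are not from the diff:

```python
from typing import Callable

import torch.nn as nn


def _choose_forward(module: nn.Module, hub_layer: nn.Module, fallback: Callable) -> Callable:
    # An absent `has_backward` counts as True, so only layers that
    # explicitly opt out are skipped during training.
    if module.training and not getattr(hub_layer, "has_backward", True):
        return fallback
    return hub_layer.forward
```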
@@ -240,7 +244,8 @@ def _validate_layer(*, check_cls, cls):
     # ... or predefined member variables.
     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
     cls_members = {name for name, _ in inspect.getmembers(cls)}
-    if cls_members - torch_module_members != set():
+    difference = cls_members - torch_module_members
+    if difference != set() and difference != {"has_backward"}:
         raise TypeError("Layer must not contain additional members.")
 
     # Check whether the forward signatures are similar.
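With the relaxed check, the only extra class member a hub layer may declare beyond those of `nn.Module` is `has_backward`. A quick illustration of the membership arithmetic, using a toy class (the name `MarkedLayer` is made up for this example):

```python
import inspect

import torch.nn as nn


class MarkedLayer(nn.Module):
    has_backward: bool = False

    def forward(self, x):
        return x


torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(MarkedLayer)}
difference = cls_members - torch_module_members
# Only the backward marker differs, so the validator accepts this layer.
assert difference == {"has_backward"}
```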
@@ -203,3 +203,75 @@ def test_validate_kernel_layer():
 
     with pytest.raises(TypeError, match="different kind of arguments"):
         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
+
+
+def test_fallback_used_when_training():
+    @use_kernel_forward_from_hub("Linear")
+    class TorchLinear(nn.Linear):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Used to check that we called hub kernel.
+            self.n_calls = 0
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            self.n_calls += 1
+            return super().forward(input)
+
+    linear = TorchLinear(32, 32).to("cuda")
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearImplicitBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(type="cuda"): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearNoBackward",
+                )
+            }
+        }
+    ):
+        linear.train()
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 1
+
+        linear.eval()
+        linear(X)
+        assert linear.n_calls == 1