mirror of https://github.com/huggingface/kernels.git
synced 2025-10-21 13:33:48 +08:00

Comparing v0.10.4...stateful-l (1 commit)

Commit a988871e9e
@@ -17,8 +17,10 @@ from types import MethodType, ModuleType
 from typing import (
     TYPE_CHECKING,
     Dict,
+    Mapping,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +870,14 @@ def kernelize(
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)

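For reference, a minimal standalone sketch of the device-resolution logic introduced above; `resolve_device` is a hypothetical helper, and the real code additionally wraps the type in this module's `Device` class:

    import torch
    import torch.nn as nn

    def resolve_device(model: nn.Module, device=None):
        # Mirrors the three branches in kernelize: infer from the model,
        # accept a string, or accept a torch.device directly.
        if device is None:
            device = next(model.parameters()).device
            device_type = device.type
        elif isinstance(device, str):
            device_type = device
            device = torch.device(device)
        else:
            device_type = device.type
        return device, device_type

    model = nn.Linear(4, 4)
    print(resolve_device(model))         # (device(type='cpu'), 'cpu')
    print(resolve_device(model, "cpu"))  # (device(type='cpu'), 'cpu')

Note that both a concrete `torch.device` and a device type are now kept: the device is needed later to build per-module state, while the type still drives kernel selection.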
@@ -884,7 +890,7 @@ def kernelize(
         layer_name = module_class.kernel_layer_name

         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def kernelize(
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         # Get kernel options for the device
@@ -909,7 +915,7 @@ def kernelize(
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repos = property_repos.repos
@@ -919,7 +925,7 @@ def kernelize(
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def kernelize(
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo, repo_mode = repo_with_mode
@@ -951,6 +957,7 @@ def kernelize(
         )

         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1044,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")

     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
-    # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    unique_members = _unique_layer_members(cls)
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )

     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters

+    params: Mapping[str, inspect.Parameter]
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = params.copy()
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
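A self-contained sketch of the signature rule above, using hypothetical `RefLayer` and `StatefulLayer` classes: a stateful layer's `forward_with_state` must match the reference `forward` once the extra `state` parameter is dropped.

    import inspect

    class RefLayer:
        def forward(self, x): ...

    class StatefulLayer:
        def create_state(self, device, module): ...
        def forward_with_state(self, state, x): ...

    def comparable_params(cls, stateful: bool):
        if stateful:
            # Copy out of the read-only mappingproxy, then drop `state` so the
            # remaining parameters line up with a plain `forward` signature.
            params = dict(inspect.signature(cls.forward_with_state).parameters)
            del params["state"]
            return params
        return inspect.signature(cls.forward).parameters

    # Both yield ['self', 'x'], so validation passes.
    assert list(comparable_params(StatefulLayer, True)) == list(
        comparable_params(RefLayer, False)
    )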
@@ -1074,7 +1093,7 @@ def _is_rocm_platform():
     return torch.version.hip is not None


-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
@@ -1082,7 +1101,13 @@ def _find_device(model: "nn.Module") -> Device:
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )

-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
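The split means `_find_device` now answers which concrete device the model is on (needed later by `create_state`), while `_find_device_type` keeps the old classification behavior, refining `cuda` when running on a HIP build. A rough standalone sketch, with the module's `Device` wrapper replaced by a plain string and the `"rocm"` label assumed for illustration:

    import torch
    import torch.nn as nn

    def find_device(model: nn.Module) -> torch.device:
        try:
            return next(model.parameters()).device
        except StopIteration:
            raise ValueError("Cannot determine model device, provide as `device` argument.")

    def find_device_type(model: nn.Module) -> str:
        dev_type = find_device(model).type
        # Refine based on the actual platform; the real code returns Device(...).
        if dev_type == "cuda" and torch.version.hip is not None:
            dev_type = "rocm"  # assumed label for the ROCm case
        return dev_type

    print(find_device_type(nn.Linear(2, 2)))  # "cpu"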
@@ -1103,6 +1128,7 @@ def _find_capability() -> int:

 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1154,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)


-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)  # type: ignore[attr-defined]
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]


 def _validate_layer_has_mode(
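The key move here: for stateful layers, `create_state` runs once at kernelize time and the result is captured in a closure, so the replaced `forward` keeps the original call signature. A minimal sketch with a hypothetical `StatefulReLU` kernel layer:

    from types import MethodType
    import torch
    import torch.nn as nn

    class StatefulReLU:
        """Hypothetical stateful kernel layer following the protocol above."""

        @staticmethod
        def create_state(device: torch.device, module: nn.Module):
            # Built once at kernelize time; could hold workspace buffers, etc.
            return {"device": device}

        def forward_with_state(self, state, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x)

    module = nn.ReLU()
    state = StatefulReLU.create_state(torch.device("cpu"), module)

    def forward(self, *args, **kwargs):
        # `state` is captured by the closure, so the public signature of
        # the module's forward is unchanged.
        return StatefulReLU.forward_with_state(self, state, *args, **kwargs)

    module.forward = MethodType(forward, module)
    print(module(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])

Because the state is a closure constant rather than an argument, callers (and `torch.compile`) see the same interface as a plain `forward`.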
@@ -1179,3 +1215,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer

     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
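In effect, `_unique_layer_members` computes which attribute names a layer adds on top of `nn.Module`, and `_is_stateful_layer` treats the presence of `forward_with_state` (paired with `create_state`) as the marker. A quick demonstration with hypothetical layers:

    import inspect
    import torch.nn as nn

    def unique_layer_members(layer) -> set:
        torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
        return {name for name, _ in inspect.getmembers(layer)} - torch_module_members

    class PlainReLU(nn.Module):
        def forward(self, x):
            return x.relu()

    class StatefulReLU(nn.Module):
        def create_state(self, device, module):
            return None

        def forward_with_state(self, state, x):
            return x.relu()

    print(unique_layer_members(PlainReLU))     # set(): `forward` already exists on nn.Module
    print(unique_layer_members(StatefulReLU))  # {'create_state', 'forward_with_state'}

Overriding an existing `nn.Module` method like `forward` does not count as a unique member, which is why this check cleanly separates stateful layers from plain ones.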
The remaining hunks are from the test suite, in a separate file:

@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 from torch.testing import assert_close

 from kernels import (
     CUDAProperties,
@@ -321,6 +322,47 @@ def test_local_layer_repo(device):
     assert linear.n_calls == 0


+def test_stateful_layer(device):
+    @use_kernel_forward_from_hub("ReluWithHiddenSize")
+    class ReluWithHiddenSize(nn.Module):
+        hidden_size: int
+
+        def __init__(self, hidden_size: int):
+            super().__init__()
+            self.hidden_size = hidden_size
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return F.relu(x)
+
+    model = ReluWithHiddenSize(hidden_size=64).to(device)
+    x = torch.randn((32, 64), device=device)
+    y_ref = model(x)
+
+    with use_kernel_mapping(
+        {
+            "ReluWithHiddenSize": {
+                "cuda": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+                "xpu": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+            }
+        },
+        inherit_mapping=False,
+    ):
+        model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
+
+    y = model(x)
+    assert_close(y, y_ref)
+
+    model = torch.compile(model, fullgraph=True)
+    y = model(x)
+    assert_close(y, y_ref)
+
+
 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])
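The test maps `ReluWithHiddenSize` to the `StatefulReLU` layer from `kernels-test/state-test`. The contents of that Hub repository are not part of this commit, but given the validation rules above it plausibly looks something like the sketch below (everything beyond the `StatefulReLU` name and the allowed member set is an assumption):

    import torch
    import torch.nn as nn
    from torch.nn import functional as F

    class StatefulReLU(nn.Module):
        # The only extra members a validated layer may declare besides
        # create_state/forward_with_state:
        can_torch_compile = True
        has_backward = True

        @staticmethod
        def create_state(device: torch.device, module: nn.Module):
            # Could capture attributes of the wrapped module, e.g. hidden_size.
            return {"hidden_size": getattr(module, "hidden_size", None)}

        def forward_with_state(self, state, x: torch.Tensor) -> torch.Tensor:
            return F.relu(x)

Both flags must be true for this test, since it kernelizes with `Mode.TRAINING | Mode.TORCH_COMPILE` and then compiles with `fullgraph=True`.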