Compare commits


1 Commit

SHA1 Message Date
a988871e9e Add support for stateful layers 2025-09-19 10:14:43 +00:00
2 changed files with 113 additions and 17 deletions
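
This change lets a kernel layer carry per-module state: such a layer defines `create_state(device, module)`, which `kernelize` calls once when it binds the layer, and `forward_with_state(self, state, ...)`, which receives that object on every call. A minimal sketch of the contract (the class and its contents are illustrative, not taken from a real kernel repository):

    # Minimal sketch of a stateful kernel layer under the new contract.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SketchStatefulLayer(nn.Module):
        @staticmethod
        def create_state(device: torch.device, module: nn.Module) -> dict:
            # Called once per wrapped module at kernelize time; may inspect the
            # module and allocate buffers on `device`.
            return {"scratch": torch.empty(0, device=device)}

        def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
            # Called instead of `forward`; apart from `state`, the signature must
            # match the forward of the class it replaces.
            return F.relu(x)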

View File

@@ -17,8 +17,10 @@ from types import MethodType, ModuleType
 from typing import (
     TYPE_CHECKING,
     Dict,
+    Mapping,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +870,14 @@ def kernelize(
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)
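
`kernelize` now needs the concrete `torch.device` (to create layer state) in addition to the kernels `Device` type used for repository lookup, so a string `device` argument is normalized with `torch.device(...)` and an omitted one is inferred from the model's parameters. A small usage sketch (assumes a CUDA machine; the model is a placeholder):

    # Usage sketch: the `device` argument may be a string, a torch.device, or omitted.
    import torch.nn as nn

    from kernels import Mode, kernelize

    model = nn.Sequential(nn.Linear(8, 8)).to("cuda")

    # String is validated and normalized to torch.device("cuda").
    model = kernelize(model, mode=Mode.INFERENCE, device="cuda")

    # Omitted: the device is inferred from the model's parameters.
    model = kernelize(model, mode=Mode.INFERENCE)
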
@@ -884,7 +890,7 @@ def kernelize(
         layer_name = module_class.kernel_layer_name

         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def kernelize(
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         # Get kernel options for the device
@@ -909,7 +915,7 @@ def kernelize(
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repos = property_repos.repos
@@ -919,7 +925,7 @@ def kernelize(
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def kernelize(
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo, repo_mode = repo_with_mode
@@ -951,6 +957,7 @@ def kernelize(
         )

         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1044,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")

     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)

     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )

     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters

+    params: Mapping[str, inspect.Parameter]
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = params.copy()
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters

     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,7 +1093,7 @@ def _is_rocm_platform():
     return torch.version.hip is not None


-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
@@ -1082,7 +1101,13 @@ def _find_device(model: "nn.Module") -> Device:
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )

-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
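
`_find_device` now returns the plain `torch.device`, while the new `_find_device_type` keeps the old behaviour of deriving the kernels `Device`, refining `cuda` to ROCm on HIP builds. A simplified sketch of that split (not the library code; it returns a plain string instead of a `Device`):

    # Simplified sketch of the split between the two helpers.
    import torch
    import torch.nn as nn


    def find_device(model: nn.Module) -> torch.device:
        # The concrete device, e.g. torch.device("cuda", 0).
        return next(model.parameters()).device


    def find_device_type(model: nn.Module) -> str:
        # The coarse device *type*, with "cuda" refined to "rocm" on HIP builds.
        dev_type = find_device(model).type
        if dev_type == "cuda" and torch.version.hip is not None:
            return "rocm"
        return dev_type
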
@@ -1103,6 +1128,7 @@ def _find_capability() -> int:

 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1154,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)


-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)  # type: ignore[attr-defined]
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]


 def _validate_layer_has_mode(
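
In the new `_replace_forward`, state is created once per module and captured in a closure, so every later call of the bound `forward` (including under `torch.compile`) reuses the same object. A standalone sketch of this binding with a mock layer (not a hub kernel):

    # Standalone sketch of the closure-based binding used for stateful layers.
    from types import MethodType

    import torch
    import torch.nn as nn


    class MockStatefulLayer(nn.Module):
        @staticmethod
        def create_state(device: torch.device, module: nn.Module) -> dict:
            return {"calls": 0}

        def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
            state["calls"] += 1
            return torch.relu(x)


    module = nn.ReLU()
    state = MockStatefulLayer.create_state(torch.device("cpu"), module)  # created once


    def forward(self, *args, **kwargs):
        # `state` is closed over; it is shared by all calls on this module.
        return MockStatefulLayer.forward_with_state(self, state, *args, **kwargs)


    module.forward = MethodType(forward, module)
    module(torch.randn(4))
    module(torch.randn(4))
    assert state["calls"] == 2
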
@@ -1179,3 +1215,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer

     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
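
`_is_stateful_layer` treats `forward_with_state` as the marker for statefulness and insists that `create_state` comes with it; a layer defining only one of the two is rejected with a `TypeError` during validation, and all other layers keep the original stateless code path.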

View File

@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.testing import assert_close

 from kernels import (
     CUDAProperties,
@@ -321,6 +322,47 @@ def test_local_layer_repo(device):
     assert linear.n_calls == 0


+def test_stateful_layer(device):
+    @use_kernel_forward_from_hub("ReluWithHiddenSize")
+    class ReluWithHiddenSize(nn.Module):
+        hidden_size: int
+
+        def __init__(self, hidden_size: int):
+            super().__init__()
+            self.hidden_size = hidden_size
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return F.relu(x)
+
+    model = ReluWithHiddenSize(hidden_size=64).to(device)
+    x = torch.randn((32, 64), device=device)
+    y_ref = model(x)
+
+    with use_kernel_mapping(
+        {
+            "ReluWithHiddenSize": {
+                "cuda": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+                "xpu": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+            }
+        },
+        inherit_mapping=False,
+    ):
+        model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
+
+    y = model(x)
+    assert_close(y, y_ref)
+
+    model = torch.compile(model, fullgraph=True)
+    y = model(x)
+    assert_close(y, y_ref)
+
+
 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])