Compare commits

...

1 Commit

Author SHA1 Message Date
b287111821 Fix forward positional argument handling 2025-03-19 13:54:18 +00:00
2 changed files with 24 additions and 6 deletions

View File

@@ -114,16 +114,16 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
cached_forward: Dict[LayerRepository, Callable] = {}
def forward(self, x, **args):
def forward(self, x, *args, **kwargs):
kernel = _KERNEL_MAPPING.get().get(layer_name)
if kernel is None:
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
return fallback_forward(self, x, **args)
return fallback_forward(self, x, *args, **kwargs)
device = getattr(x, "device", None)
if device is None:
return fallback_forward(self, x, **args)
return fallback_forward(self, x, *args, **kwargs)
repo = kernel.get(Device(type=device.type))
if repo is None:
@@ -131,12 +131,12 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
)
return fallback_forward(self, x, **args)
return fallback_forward(self, x, *args, **kwargs)
# Short-circuit if we already loaded the layer.
layer_forward = cached_forward.get(repo, None)
if layer_forward is not None:
return layer_forward(self, x, **args)
return layer_forward(self, x, *args, **kwargs)
layer = _get_kernel_layer(
repo_id=repo.repo_id,
@@ -155,7 +155,7 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
layer_forward = layer.forward
cached_forward[repo] = layer_forward
return layer_forward(self, x, **args)
return layer_forward(self, x, *args, **kwargs)
cls.forward = forward

View File

@@ -53,6 +53,24 @@ class SiluAndMulStringDevice(SiluAndMul):
pass
def test_arg_kinds():
    """Check that a hub-decorated layer passes through every argument kind.

    The layer's ``forward`` takes two positional arguments, one required
    keyword-only argument, and one keyword-only argument with a default;
    calling the layer must return them unchanged.
    """

    @use_kernel_forward_from_hub("ArgKind")
    class ArgKind(nn.Module):
        def forward(self, arg1, arg2, *, kwarg1, kwarg2=42):
            return (arg1, arg2, kwarg1, kwarg2)

    layer = ArgKind()

    # Omitting ``kwarg2`` must fall back to its declared default.
    with_default = layer("foo", "bar", kwarg1="baz")
    assert with_default == ("foo", "bar", "baz", 42)

    # Supplying ``kwarg2`` explicitly must override the default.
    with_override = layer("foo", "bar", kwarg1="baz", kwarg2=5)
    assert with_override == ("foo", "bar", "baz", 5)
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):