mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-22 05:48:52 +08:00
Compare commits
1 Commits
v0.10.1
...
fixup-arg-
Author | SHA1 | Date | |
---|---|---|---|
b287111821 |
@ -114,16 +114,16 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
|
||||
cached_forward: Dict[LayerRepository, Callable] = {}
|
||||
|
||||
def forward(self, x, **args):
|
||||
def forward(self, x, *args, **kwargs):
|
||||
kernel = _KERNEL_MAPPING.get().get(layer_name)
|
||||
if kernel is None:
|
||||
if not use_fallback:
|
||||
raise ValueError(f"No layer mapping for `{layer_name}`")
|
||||
return fallback_forward(self, x, **args)
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
device = getattr(x, "device", None)
|
||||
if device is None:
|
||||
return fallback_forward(self, x, **args)
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
repo = kernel.get(Device(type=device.type))
|
||||
if repo is None:
|
||||
@ -131,12 +131,12 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
raise ValueError(
|
||||
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
|
||||
)
|
||||
return fallback_forward(self, x, **args)
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
# Short-circuit if we already loaded the layer.
|
||||
layer_forward = cached_forward.get(repo, None)
|
||||
if layer_forward is not None:
|
||||
return layer_forward(self, x, **args)
|
||||
return layer_forward(self, x, *args, **kwargs)
|
||||
|
||||
layer = _get_kernel_layer(
|
||||
repo_id=repo.repo_id,
|
||||
@ -155,7 +155,7 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
layer_forward = layer.forward
|
||||
cached_forward[repo] = layer_forward
|
||||
|
||||
return layer_forward(self, x, **args)
|
||||
return layer_forward(self, x, *args, **kwargs)
|
||||
|
||||
cls.forward = forward
|
||||
|
||||
|
@ -53,6 +53,24 @@ class SiluAndMulStringDevice(SiluAndMul):
|
||||
pass
|
||||
|
||||
|
||||
def test_arg_kinds():
|
||||
@use_kernel_forward_from_hub("ArgKind")
|
||||
class ArgKind(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
arg1,
|
||||
arg2,
|
||||
*,
|
||||
kwarg1,
|
||||
kwarg2=42,
|
||||
):
|
||||
return (arg1, arg2, kwarg1, kwarg2)
|
||||
|
||||
arg_kind = ArgKind()
|
||||
assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
|
||||
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
def test_hub_forward(cls, device):
|
||||
|
Reference in New Issue
Block a user