mirror of
				https://github.com/huggingface/kernels.git
				synced 2025-11-04 22:24:32 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			v0.10.1
			...
			fixup-arg-
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b287111821 | 
@ -114,16 +114,16 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
 | 
			
		||||
 | 
			
		||||
    cached_forward: Dict[LayerRepository, Callable] = {}
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, **args):
 | 
			
		||||
    def forward(self, x, *args, **kwargs):
 | 
			
		||||
        kernel = _KERNEL_MAPPING.get().get(layer_name)
 | 
			
		||||
        if kernel is None:
 | 
			
		||||
            if not use_fallback:
 | 
			
		||||
                raise ValueError(f"No layer mapping for `{layer_name}`")
 | 
			
		||||
            return fallback_forward(self, x, **args)
 | 
			
		||||
            return fallback_forward(self, x, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        device = getattr(x, "device", None)
 | 
			
		||||
        if device is None:
 | 
			
		||||
            return fallback_forward(self, x, **args)
 | 
			
		||||
            return fallback_forward(self, x, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        repo = kernel.get(Device(type=device.type))
 | 
			
		||||
        if repo is None:
 | 
			
		||||
@ -131,12 +131,12 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"No layer mapping for `{layer_name}` with device type `{device.type}`"
 | 
			
		||||
                )
 | 
			
		||||
            return fallback_forward(self, x, **args)
 | 
			
		||||
            return fallback_forward(self, x, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # Short-circuit if we already loaded the layer.
 | 
			
		||||
        layer_forward = cached_forward.get(repo, None)
 | 
			
		||||
        if layer_forward is not None:
 | 
			
		||||
            return layer_forward(self, x, **args)
 | 
			
		||||
            return layer_forward(self, x, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        layer = _get_kernel_layer(
 | 
			
		||||
            repo_id=repo.repo_id,
 | 
			
		||||
@ -155,7 +155,7 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
 | 
			
		||||
        layer_forward = layer.forward
 | 
			
		||||
        cached_forward[repo] = layer_forward
 | 
			
		||||
 | 
			
		||||
        return layer_forward(self, x, **args)
 | 
			
		||||
        return layer_forward(self, x, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    cls.forward = forward
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -53,6 +53,24 @@ class SiluAndMulStringDevice(SiluAndMul):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_arg_kinds():
 | 
			
		||||
    @use_kernel_forward_from_hub("ArgKind")
 | 
			
		||||
    class ArgKind(nn.Module):
 | 
			
		||||
        def forward(
 | 
			
		||||
            self,
 | 
			
		||||
            arg1,
 | 
			
		||||
            arg2,
 | 
			
		||||
            *,
 | 
			
		||||
            kwarg1,
 | 
			
		||||
            kwarg2=42,
 | 
			
		||||
        ):
 | 
			
		||||
            return (arg1, arg2, kwarg1, kwarg2)
 | 
			
		||||
 | 
			
		||||
    arg_kind = ArgKind()
 | 
			
		||||
    assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
 | 
			
		||||
    assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
 | 
			
		||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
 | 
			
		||||
def test_hub_forward(cls, device):
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user