Mirror of https://github.com/huggingface/kernels.git, synced 2025-11-06 23:24:31 +08:00
Compare commits: v0.11.0...test-torch (1 commit)
Commit SHA1: bc1564fa34
```diff
@@ -21,11 +21,16 @@ def device():
     return "cuda"


-def test_gelu_fast(kernel, device):
+@pytest.mark.parametrize("torch_compile", [False, True])
+def test_gelu_fast(kernel, device, torch_compile):
     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
     y = torch.empty_like(x)

-    kernel.gelu_fast(y, x)
+    op = kernel.gelu_fast
+    if torch_compile:
+        op = torch.compile(op)
+
+    op(y, x)

     expected = torch.tensor(
         [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
```
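The hunk above keeps a single code path for both modes: bind the raw kernel callable to a name, conditionally wrap it in `torch.compile`, then call it. Below is a minimal, self-contained sketch of that pattern; `gelu_fast` here is a hypothetical pure-PyTorch stand-in (tanh-approximation GELU) for the hub kernel the real test loads, not the actual CUDA kernel.

```python
import torch


def gelu_fast(y: torch.Tensor, x: torch.Tensor) -> None:
    # Hypothetical stand-in for the hub kernel: tanh-approximation GELU,
    # written into `y` in place to mirror the kernel's out-argument signature.
    y.copy_(0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3))))


for torch_compile in (False, True):
    op = gelu_fast
    if torch_compile:
        # torch.compile wraps the callable; the in-place write to `y`
        # is an input mutation that torch.compile traces through.
        op = torch.compile(op)

    x = torch.arange(1, 10, dtype=torch.float32).view(3, 3)
    y = torch.empty_like(x)
    op(y, x)
    print(y[0])  # roughly [0.8412, 1.9546, 2.9964] for the stand-in
```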
```diff
@@ -73,7 +73,8 @@ def test_arg_kinds():

 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
-def test_hub_forward(cls, device):
+@pytest.mark.parametrize("torch_compile", [False, True])
+def test_hub_forward(cls, device, torch_compile):
     torch.random.manual_seed(0)

     silu_and_mul = SiluAndMul()
```
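Stacked `parametrize` decorators take the cross product, so this change doubles the number of `test_hub_forward` cases (2 classes × 2 devices × 2 compile modes = 8). A small sketch of that stacking behavior, independent of the kernels library:

```python
import pytest


@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("torch_compile", [False, True])
def test_cross_product(device, torch_compile):
    # pytest generates one test per (device, torch_compile) pair: 4 in total.
    assert device in ("cuda", "cpu")
    assert torch_compile in (False, True)
```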
```diff
@@ -81,6 +82,8 @@ def test_hub_forward(cls, device):
     Y = silu_and_mul(X)

     silu_and_mul_with_kernel = cls()
+    if torch_compile:
+        silu_and_mul_with_kernel = torch.compile(silu_and_mul_with_kernel)
     Y_kernel = silu_and_mul_with_kernel(X)

     torch.testing.assert_close(Y_kernel, Y)
```
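`torch.compile` also accepts an `nn.Module` instance and returns a compiled wrapper with the same call signature, which is what the last hunk relies on. Below is a minimal sketch; the `SiluAndMul` body (split the last dimension in half, compute `silu(a) * b`) is an assumption inferred from the test's name, not the reference class the real test uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiluAndMul(nn.Module):
    # Assumed reference semantics: split the last dim in half, silu(a) * b.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = x.chunk(2, dim=-1)
        return F.silu(a) * b


torch.random.manual_seed(0)
X = torch.randn(4, 8)

eager = SiluAndMul()
compiled = torch.compile(SiluAndMul())  # compiled module, same call signature

torch.testing.assert_close(compiled(X), eager(X))
```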