Compare commits

...

1 Commits

2 changed files with 11 additions and 3 deletions

View File

@ -21,11 +21,16 @@ def device():
return "cuda"
def test_gelu_fast(kernel, device):
@pytest.mark.parametrize("torch_compile", [False, True])
def test_gelu_fast(kernel, device, torch_compile):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
op = kernel.gelu_fast
if torch_compile:
op = torch.compile(op)
op(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],

View File

@ -73,7 +73,8 @@ def test_arg_kinds():
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
@pytest.mark.parametrize("torch_compile", [False, True])
def test_hub_forward(cls, device, torch_compile):
torch.random.manual_seed(0)
silu_and_mul = SiluAndMul()
@ -81,6 +82,8 @@ def test_hub_forward(cls, device):
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
if torch_compile:
silu_and_mul_with_kernel = torch.compile(silu_and_mul_with_kernel)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)