Add Helion softmax test (#155976)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155976
Approved by: https://github.com/jansel
This commit is contained in:
Oguz Ulgen
2025-06-13 18:37:45 -07:00
committed by PyTorch MergeBot
parent 9338d85d45
commit 92b7ed6d07

View File

@ -41,6 +41,26 @@ class HelionTests(TestCase):
self.assertEqual(out, x + y) self.assertEqual(out, x + y)
self.assertEqual(compiled_out, x + y) self.assertEqual(compiled_out, x + y)
@requires_helion()
def test_softmax_view_reshape(self):
@helion.kernel(config={"block_size": 1})
def softmax(x: torch.Tensor) -> torch.Tensor:
n, _m = x.size()
out = torch.empty_like(x)
for tile_n in hl.tile(n):
values = x[tile_n, :]
amax = torch.amax(values, dim=1).view(tile_n, 1)
exp = torch.exp(values - amax)
sum_exp = torch.reshape(torch.sum(exp, dim=1), [tile_n, 1])
out[tile_n, :] = exp / sum_exp
return out
x = torch.randn([1024, 1024], device=GPU_TYPE, dtype=torch.float16)
result = softmax(x)
self.assertEqual(
result, torch.nn.functional.softmax(x, dim=1), rtol=1e-2, atol=1e-1
)
instantiate_parametrized_tests(HelionTests) instantiate_parametrized_tests(HelionTests)