Add Helion softmax test (#155976)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155976
Approved by: https://github.com/jansel
committed by PyTorch MergeBot
parent 9338d85d45
commit 92b7ed6d07
@@ -41,6 +41,26 @@ class HelionTests(TestCase):
         self.assertEqual(out, x + y)
         self.assertEqual(compiled_out, x + y)
 
+    @requires_helion()
+    def test_softmax_view_reshape(self):
+        @helion.kernel(config={"block_size": 1})
+        def softmax(x: torch.Tensor) -> torch.Tensor:
+            n, _m = x.size()
+            out = torch.empty_like(x)
+            for tile_n in hl.tile(n):
+                values = x[tile_n, :]
+                amax = torch.amax(values, dim=1).view(tile_n, 1)
+                exp = torch.exp(values - amax)
+                sum_exp = torch.reshape(torch.sum(exp, dim=1), [tile_n, 1])
+                out[tile_n, :] = exp / sum_exp
+            return out
+
+        x = torch.randn([1024, 1024], device=GPU_TYPE, dtype=torch.float16)
+        result = softmax(x)
+        self.assertEqual(
+            result, torch.nn.functional.softmax(x, dim=1), rtol=1e-2, atol=1e-1
+        )
+
 
 instantiate_parametrized_tests(HelionTests)
 
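For reference, the Helion kernel in this test is the standard numerically stable softmax: it tiles over rows with hl.tile and subtracts each row's maximum before exponentiating, so exp() stays in range even in float16. A minimal eager-mode sketch of the same math in plain PyTorch (softmax_reference is an illustrative name, not something defined in this diff):

import torch

def softmax_reference(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable softmax over the last dimension: subtracting the
    # per-row max before exp() keeps the exponentials in range for float16.
    amax = torch.amax(x, dim=-1, keepdim=True)
    exp = torch.exp(x - amax)
    return exp / torch.sum(exp, dim=-1, keepdim=True)

# Quick check against the built-in implementation on a small CPU tensor.
x = torch.randn(4, 8)
torch.testing.assert_close(softmax_reference(x), torch.nn.functional.softmax(x, dim=-1))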