mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[inductor] fix 3d tiled online softmax (#162341)
The online_softmax_reduce runtime helper previously assumes the input tl.Tensor's are 2d tensors. But with tiled reduction, they can be 3d (y, x, r). Pull Request resolved: https://github.com/pytorch/pytorch/pull/162341 Approved by: https://github.com/jansel, https://github.com/eellison ghstack dependencies: #162311
This commit is contained in:
committed by
PyTorch MergeBot
parent
d8b6622bb6
commit
7b8a64557d
@ -293,6 +293,19 @@ class TestOnlineSoftmax(TestCase):
|
||||
self.assertTrue(not act.isnan().any())
|
||||
self.assertTrue(torch.allclose(ref, act))
|
||||
|
||||
@inductor_config.patch(split_reductions=False)
|
||||
def test_3d_tiled_online_softmax(self):
|
||||
def f(x, y):
|
||||
return (x * y).softmax(dim=-1)
|
||||
|
||||
M, N, K = 32, 8, 1024
|
||||
|
||||
x = torch.randn(K, N, M, device=GPU_TYPE).permute(2, 1, 0)
|
||||
y = torch.randn(K, M, N, device=GPU_TYPE).permute(1, 2, 0)
|
||||
|
||||
opt_f = torch.compile(f)
|
||||
torch.testing.assert_close(f(x, y), opt_f(x, y), atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestOnlineSoftmax)
|
||||
|
||||
|
@ -176,7 +176,7 @@ def exp(x, use_fast_math: tl.constexpr):
|
||||
@triton.jit
|
||||
def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr):
|
||||
out_max = max2(lhs_max, dim)
|
||||
out_max_keepdim = out_max[:, None]
|
||||
out_max_keepdim = tl.expand_dims(out_max, dim)
|
||||
delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim)
|
||||
out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim)
|
||||
return out_max, out_sum
|
||||
|
Reference in New Issue
Block a user