From 7b8a64557df3218dc004a7d0486ce7a7ea7171d5 Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Sun, 7 Sep 2025 23:42:06 -0700
Subject: [PATCH] [inductor] fix 3d tiled online softmax (#162341)

The online_softmax_reduce runtime helper previously assumes the input
tl.Tensor's are 2d tensors. But with tiled reduction, they can be 3d
(y, x, r).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162341
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #162311
---
 test/inductor/test_online_softmax.py      | 13 +++++++++++++
 torch/_inductor/runtime/triton_helpers.py |  2 +-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py
index 1e94ff1f4987..808757b7e041 100644
--- a/test/inductor/test_online_softmax.py
+++ b/test/inductor/test_online_softmax.py
@@ -293,6 +293,19 @@ class TestOnlineSoftmax(TestCase):
         self.assertTrue(not act.isnan().any())
         self.assertTrue(torch.allclose(ref, act))
 
+    @inductor_config.patch(split_reductions=False)
+    def test_3d_tiled_online_softmax(self):
+        def f(x, y):
+            return (x * y).softmax(dim=-1)
+
+        M, N, K = 32, 8, 1024
+
+        x = torch.randn(K, N, M, device=GPU_TYPE).permute(2, 1, 0)
+        y = torch.randn(K, M, N, device=GPU_TYPE).permute(1, 2, 0)
+
+        opt_f = torch.compile(f)
+        torch.testing.assert_close(f(x, y), opt_f(x, y), atol=1e-3, rtol=1e-3)
+
 
 instantiate_parametrized_tests(TestOnlineSoftmax)
 
diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py
index 1c0285637cf4..e003615b218f 100644
--- a/torch/_inductor/runtime/triton_helpers.py
+++ b/torch/_inductor/runtime/triton_helpers.py
@@ -176,7 +176,7 @@ def exp(x, use_fast_math: tl.constexpr):
 @triton.jit
 def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr):
     out_max = max2(lhs_max, dim)
-    out_max_keepdim = out_max[:, None]
+    out_max_keepdim = tl.expand_dims(out_max, dim)
     delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim)
     out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim)
     return out_max, out_sum