From a714437093ed196eee28f7de454cf4c41badc098 Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Tue, 2 Sep 2025 17:21:21 -0700
Subject: [PATCH] [ez][inductor] add a few outer dimension reduction cases for LOAF (#162028)

For the not-able-to-fuse issue reported in
https://github.com/pytorch/pytorch/issues/93718, LOAF (loop ordering after
fusion) can fuse the outer-dimension softmax into a single kernel and brings a
1.87x speedup for the example shape mentioned in the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162028
Approved by: https://github.com/jansel, https://github.com/eellison
---
 test/inductor/test_loop_ordering.py | 42 +++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py
index 5773337690b6..a37d01038db6 100644
--- a/test/inductor/test_loop_ordering.py
+++ b/test/inductor/test_loop_ordering.py
@@ -8,6 +8,7 @@ import numpy as np
 import sympy
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import same
@@ -473,6 +474,47 @@ class LoopOrderingTest(TestCase):
         expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes  # output
         self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)
 
+    def test_outer_dimension_softmax(self):
+        """
+        This test repros the not-able-to-fuse problem for outer-dimension
+        softmax reported here: https://github.com/pytorch/pytorch/issues/93718
+
+        Perf data on H100:
+        - without loop ordering after fusion: 0.564 ms
+        - with loop ordering after fusion: 0.302 ms
+        This is a 1.87x speedup.
+
+        """
+        x = torch.randn(32, 2**21, device=GPU_TYPE)
+
+        def f(x):
+            return F.softmax(x, dim=0)
+
+        self.do_acc_test(f, x)
+        self.assertEqual(1, metrics.generated_kernel_count)
+
+    def test_outer_dimension_sum_fuse_with_pw(self):
+        """
+        Test the fusion of an outer-dimension sum with a following pointwise op.
+        Perf data on H100:
+        - without loop ordering after fusion: 0.436 ms
+        - with loop ordering after fusion: 0.260 ms
+        This is a 1.68x speedup.
+        """
+        x = torch.randn(32, 2**21, device=GPU_TYPE)
+
+        def f(x):
+            return x.sum(dim=0, keepdim=True) + x
+
+        self.do_acc_test(f, x)
+        self.assertEqual(1, metrics.generated_kernel_count)
+
+        if DO_PERF_TEST:
+            from triton.testing import do_bench
+
+            optf = torch.compile(f)
+            print(f"ms={do_bench(lambda: optf(x))}")
+
     # Disable split reduction to make it easier to calculate the expected
     # number of bytes accessed. In this case, split reduction does not
     # help perf much.
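
For reference, here is a minimal standalone sketch of the outer-dimension softmax
benchmark the commit message describes, usable outside the test harness. It
assumes a CUDA GPU with Triton's do_bench helper available, and that LOAF is
toggled via the inductor config flag loop_ordering_after_fusion; that flag name
is an assumption, not something this patch confirms.

# Standalone benchmark sketch for the outer-dimension softmax case.
# Assumption (not from the patch): LOAF is controlled by the inductor config
# flag loop_ordering_after_fusion; flip it to compare the two timings.
import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F
from triton.testing import do_bench

inductor_config.loop_ordering_after_fusion = True

x = torch.randn(32, 2**21, device="cuda")

def outer_softmax(x):
    # dim=0 makes this a reduction over the outer (large-stride) dimension
    return F.softmax(x, dim=0)

compiled = torch.compile(outer_softmax)
compiled(x)  # warm-up run so compile time is excluded from the measurement
print(f"outer-dim softmax: {do_bench(lambda: compiled(x)):.3f} ms")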