[ez][inductor] add a few outer dimension reduction cases for LOAF (#162028)

For the not able to fuse issue reported here: https://github.com/pytorch/pytorch/issues/93718 , LOAF can fuse the outer dimension softmax into a single kernel and brings 1.87x speedup for the example shape mentioned in the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162028
Approved by: https://github.com/jansel, https://github.com/eellison
This commit is contained in:
Shunting Zhang
2025-09-02 17:21:21 -07:00
committed by PyTorch MergeBot
parent bffc7dd1f3
commit a714437093

View File

@ -8,6 +8,7 @@ import numpy as np
import sympy
import torch
import torch.nn.functional as F
from torch import nn
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
@ -473,6 +474,47 @@ class LoopOrderingTest(TestCase):
expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes # output
self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)
def test_outer_dimension_softmax(self):
"""
This test repros the not able to fuse problem for outer dimension
softmax reported here: https://github.com/pytorch/pytorch/issues/93718
Perf data on h100:
- without loop ordering after fusion 0.564 ms
- with loop ordering after fusion 0.302 ms
This is 1.87x speedup.
"""
x = torch.randn(32, 2**21, device=GPU_TYPE)
def f(x):
return F.softmax(x, dim=0)
self.do_acc_test(f, x)
self.assertEqual(1, metrics.generated_kernel_count)
def test_outer_dimension_sum_fuse_with_pw(self):
"""
Test the fusion of an outer dimension sum with a followed pointwise.
Perf data on h100:
- without loop ordering after fusion 0.436 ms
- with loop ordering after fusion 0.260 ms
This is 1.68x speedup.
"""
x = torch.randn(32, 2**21, device=GPU_TYPE)
def f(x):
return x.sum(dim=0, keepdim=True) + x
self.do_acc_test(f, x)
self.assertEqual(1, metrics.generated_kernel_count)
if DO_PERF_TEST:
from triton.testing import do_bench
optf = torch.compile(f)
print(f"ms={do_bench(lambda: optf(x))}")
# Disable split reduction to make it easier to calculate the expected
# number of bytes accessed. In this case, split reduction does not
# help perf much.