[ez][inductor] add a few outer dimension reduction cases for LOAF (#162028)

For the unable-to-fuse issue reported at https://github.com/pytorch/pytorch/issues/93718, LOAF (loop ordering after fusion) can fuse the outer-dimension softmax into a single kernel, bringing a 1.87x speedup for the example shape mentioned in the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162028
Approved by: https://github.com/jansel, https://github.com/eellison
Commit: a714437093 (parent: bffc7dd1f3), committed by PyTorch MergeBot
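As a quick illustration of what the new tests exercise, here is a minimal sketch of the scenario, assuming a CUDA device and the inductor config flag `loop_ordering_after_fusion` (set explicitly here; its exact name and default may vary across versions):

import torch
import torch.nn.functional as F
import torch._inductor.config as inductor_config

# Assumed knob: enable loop ordering after fusion (LOAF) explicitly.
inductor_config.loop_ordering_after_fusion = True

# Same shape as in the tests below; softmax over the outer dimension (dim=0)
# is the case that previously failed to fuse into a single kernel.
x = torch.randn(32, 2**21, device="cuda")
f = torch.compile(lambda t: F.softmax(t, dim=0))
f(x)  # with LOAF, the reduction and pointwise parts fuse into one kernel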
@@ -8,6 +8,7 @@ import numpy as np
 import sympy

 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import same
@@ -473,6 +474,47 @@ class LoopOrderingTest(TestCase):
         expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes  # output
         self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)

+    def test_outer_dimension_softmax(self):
+        """
+        Reproduce the unable-to-fuse problem for outer dimension softmax
+        reported here: https://github.com/pytorch/pytorch/issues/93718
+
+        Perf data on h100:
+        - without loop ordering after fusion: 0.564 ms
+        - with loop ordering after fusion: 0.302 ms
+        This is a 1.87x speedup.
+        """
+        x = torch.randn(32, 2**21, device=GPU_TYPE)
+
+        def f(x):
+            return F.softmax(x, dim=0)
+
+        self.do_acc_test(f, x)
+        self.assertEqual(1, metrics.generated_kernel_count)
+
+    def test_outer_dimension_sum_fuse_with_pw(self):
+        """
+        Test the fusion of an outer dimension sum with a following pointwise op.
+
+        Perf data on h100:
+        - without loop ordering after fusion: 0.436 ms
+        - with loop ordering after fusion: 0.260 ms
+        This is a 1.68x speedup.
+        """
+        x = torch.randn(32, 2**21, device=GPU_TYPE)
+
+        def f(x):
+            return x.sum(dim=0, keepdim=True) + x
+
+        self.do_acc_test(f, x)
+        self.assertEqual(1, metrics.generated_kernel_count)
+
+        if DO_PERF_TEST:
+            from triton.testing import do_bench
+
+            optf = torch.compile(f)
+            print(f"ms={do_bench(lambda: optf(x))}")
+
     # Disable split reduction to make it easier to calculate the expected
     # number of bytes accessed. In this case, split reduction does not
     # help perf much.
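The hunk's trailing comment refers to an inductor knob for split reductions. A minimal sketch of how a test might disable it, assuming the config option is named `split_reductions` and that `torch._inductor.config.patch` is available (both are assumptions, not confirmed by this diff):

import torch._inductor.config as inductor_config

# Assumed config name: with split reductions off, a reduction is not broken
# into multiple kernels, so the expected bytes accessed are easy to compute.
with inductor_config.patch({"split_reductions": False}):
    pass  # compile and run the function, then check metrics.num_bytes_accessed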