Compare commits

...

6 Commits

Author SHA1 Message Date
31148514ae Update on "[inductor] split reduction even if all reads are broadcasted"
With split reduction we can speed up the following (extreme) kernel by 48x:

```
# 56ms -> 1.163ms

import torch
from triton.testing import do_bench

def f(x):
    return x.sum(dim=(0, 1))

x = torch.randn(100000000, 1, 2, device="cuda").expand(-1, 2, -1)
opt_f = torch.compile(f)
ref = f(x)
act = opt_f(x)

torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
ms = do_bench(lambda: opt_f(x))
print(f"ms={ms:.3f}")
```
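
For intuition, a split reduction turns one long reduction into two stages: each chunk of the reduced range is summed to a partial result (chunks can run in parallel), and a second, much smaller reduction combines the partials. The following is only a standalone sketch of that idea, not Inductor's codegen; `split_sum` and `num_splits` are made-up names for illustration.

```
import torch

def split_sum(x: torch.Tensor, num_splits: int = 1024) -> torch.Tensor:
    # Flatten the reduced dims (0, 1) so the reduced range can be chunked.
    flat = x.reshape(-1, x.shape[-1])
    n = flat.shape[0]
    # Pad with zeros so the range divides evenly into num_splits chunks.
    pad = (-n) % num_splits
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad, flat.shape[-1])])
    # Stage 1: one partial sum per chunk.
    partials = flat.reshape(num_splits, -1, flat.shape[-1]).sum(dim=1)
    # Stage 2: a small reduction over the partials.
    return partials.sum(dim=0)

x = torch.randn(1000000, 1, 2).expand(-1, 2, -1)
torch.testing.assert_close(split_sum(x), x.sum(dim=(0, 1)), atol=1e-2, rtol=1e-3)
```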

Not sure whether this change will break anything; let's wait for CI.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-14 17:41:48 -08:00
e63ef129bf Update base for Update on "[inductor] split reduction even if all reads are broadcasted"
With split reduction we can speed up the following (extreme) kernel by 48x:

```
# 56ms -> 1.163ms

import torch
from triton.testing import do_bench

def f(x):
    return x.sum(dim=(0, 1))

x = torch.randn(100000000, 1, 2, device="cuda").expand(-1, 2, -1)
opt_f = torch.compile(f)
ref = f(x)
act = opt_f(x)

torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
ms = do_bench(lambda: opt_f(x))
print(f"ms={ms:.3f}")
```

Not sure whether this change will break anything; let's wait for CI.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-14 17:41:48 -08:00
ecc0f4be7f [inductor] split reduction even if all reads are broadcasted
[ghstack-poisoned]
2025-11-14 17:16:26 -08:00
72dd1930f3 Update on "[inductor] fix the decision of inner reduction"
Inductor may treat an outer reduction as an inner reduction when the reduction ranges contain a 1. This causes a weird issue where we skip fusing with mix order reduction. While I'm still debugging why that happens, I think we should fix the decision here anyway.
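
A minimal repro sketch of the shape in question (it mirrors the `test_inner_reduction_detection` test added in this stack and needs a GPU with Triton): the reduction over dims `(0, 1)` includes a size-1 dim, so it is really an outer reduction over dim 0 and should get `ReductionHint.OUTER`, not `INNER`.

```
import torch

# The reduction ranges contain a 1 (dim 1 has a single element), so the
# reduction is effectively an outer reduction over dim 0.
x = torch.randn(100000, 1, 256, device="cuda")

@torch.compile
def f(x):
    return x.sum(dim=(0, 1))

f(x)  # with the fix, Inductor emits this kernel with ReductionHint.OUTER
```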


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-13 12:20:41 -08:00
68a44e25b6 Update on "[inductor] fix the decision of inner reduction"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-12 17:33:36 -08:00
0d6bd26502 [inductor] fix the decision of inner reduction
[ghstack-poisoned]
2025-11-12 17:25:47 -08:00
3 changed files with 57 additions and 16 deletions

View File

```
@@ -270,11 +270,20 @@ class MixOrderReductionTest(TestBase):
         ],
     )
     @parametrize("split_reductions", (False, True))
-    @parametrize("shape", ((32768, 2048), (32768, 768), (32768 + 1023, 768)))
+    @parametrize(
+        "shape", ((1000000, 256), (32768, 2048), (32768, 768), (32768 + 1023, 768))
+    )
     @parametrize("max_autotune", (False, True))
     @parametrize("initial_xblock", (1, 2))
+    @parametrize("add_1dim", (False, True))
     def test_rms_norm_bwd(
-        self, wdtype, split_reductions, shape, max_autotune, initial_xblock
+        self,
+        wdtype,
+        split_reductions,
+        shape,
+        max_autotune,
+        initial_xblock,
+        add_1dim,
     ):
         # max_autotune can be slow and cost resource, trim down the tests
         # for max autotune
@@ -287,6 +296,9 @@ class MixOrderReductionTest(TestBase):
         ):
             self.skipTest("Skip non-critical tests to save resources.")
 
+        if shape != (1000000, 256) and add_1dim:
+            self.skipTest("Skip non-critical tests to save resources.")
+
         def f(x, w, eps):
             orig_dtype = x.dtype
@@ -307,6 +319,9 @@ class MixOrderReductionTest(TestBase):
         # M, N = 1152 * 500, 384
         M, N = shape
         x = torch.randn(M, N, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True)
+        if add_1dim:
+            x = x[:, None, :]
+
         w = torch.randn(N, dtype=wdtype, device=GPU_TYPE, requires_grad=True)
         dy = torch.randn_like(x)
         eps = 1e-5
```

View File

```
@@ -14626,6 +14626,36 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue(same((ref, ref_grad), (act, act_grad), tol=1e-3))
 
+    @skipIfMPS
+    def test_inner_reduction_detection(self):
+        if self.device == "cpu":
+            self.skipTest("Skip for CPU device")
+
+        x = torch.randn(100000, 1, 256, device=self.device)
+
+        @torch.compile
+        def f(x):
+            return x.sum(dim=(0, 1))
+
+        code = run_and_get_triton_code(f, x)
+        self.assertTrue("ReductionHint.OUTER" in code)
+        self.assertFalse("ReductionHint.INNER" in code)
+
+    @skipIfMPS
+    def test_broadcasted_inner_reduction_detection(self):
+        if self.device == "cpu":
+            self.skipTest("Skip for CPU device")
+
+        x = torch.randn(2000000, 1, 2, device=self.device).expand(-1, 2, -1)
+
+        @torch.compile
+        def f(x):
+            return x.sum(dim=(0, 1))
+
+        code = run_and_get_triton_code(f, x)
+        self.assertTrue("ReductionHint.OUTER" in code)
+        self.assertFalse("ReductionHint.INNER" in code)
+
     @skip_if_halide
     @requires_cuda_and_triton
     @skip_if_cpp_wrapper("skip cpp wrapper")
```

View File

```
@@ -1399,22 +1399,16 @@ class Reduction(Loops):
             # TODO this will fail for something like ((1, N) * (N, 1)).sum()
             # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
             assert read_writes.range_vars is not None
-            range_vars = [
-                r
-                for r in read_writes.range_vars
-                if isinstance(r, Expr) and not isinstance(r, sympy.Number)
-            ]
             indices = []
             changed = False
             for md in sorted(read_writes.reads, key=lambda x: x.name):
-                if all(r in md.index.free_symbols for r in range_vars):
-                    indices.append(md.index)
-                    if md.name in V.graph.name_to_buffer:
-                        buf = V.graph.name_to_buffer[md.name]
-                        original_stride = getattr(buf.layout, "stride", None)
-                        buf.decide_layout()
-                        if getattr(buf.layout, "stride", None) != original_stride:
-                            changed = True
+                indices.append(md.index)
+                if md.name in V.graph.name_to_buffer:
+                    buf = V.graph.name_to_buffer[md.name]
+                    original_stride = getattr(buf.layout, "stride", None)
+                    buf.decide_layout()
+                    if getattr(buf.layout, "stride", None) != original_stride:
+                        changed = True
             return indices, changed
 
         indices, changed = get_read_indices(r)
@@ -1435,7 +1429,9 @@ class Reduction(Loops):
             strides = V.graph.sizevars.stride_hints(
                 j, reduction_vars, list(ranges1.keys())
            )
-            outer = all(s > 1 for s in strides)
+            # A 0 stride does not make a reduction contiguous.
+            # This can happen when the reduction ranges contains a 1.
+            outer = all(s == 0 or s > 1 for s in strides)
             if outer:
                 num_outer += 1
             else:
```
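
To see why the last hunk treats a 0 stride as compatible with an outer reduction: for a broadcasted input, the expanded dim contributes a stride of 0, so the old `all(s > 1 ...)` test rejected the read as outer even though a 0 stride does not make the reduction contiguous either. A standalone illustration using tensor-level strides only (the values Inductor derives via `stride_hints` may be normalized differently):

```
import torch

x = torch.randn(2000000, 1, 2).expand(-1, 2, -1)
print(x.stride())  # (2, 0, 1): the expanded dim has stride 0

# Reducing over dims (0, 1); their strides are 2 and 0.
strides = (x.stride(0), x.stride(1))
old_outer = all(s > 1 for s in strides)            # False: 0 fails s > 1
new_outer = all(s == 0 or s > 1 for s in strides)  # True: 0 is allowed
print(old_outer, new_outer)
```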