Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-01 13:34:57 +08:00)
Update on "[inductor] track reduction before splitting"
Keep track of the reduction before splitting. In the mix-order reduction context, if one of the reductions is split, it becomes much harder to fuse with the other reduction. Track the metadata of the reduction before splitting to keep that fusion possible.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
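As a rough illustration (not part of this commit), the kind of mix-order reduction pair the change targets looks like the new test added below: the same input is reduced along the last dimension and along the leading dimensions in one call. The shapes, the "cuda" device string, and the torch.compile wrapper here are illustrative assumptions:

    import torch

    def f(x):
        # Two reductions over the same tensor in different dimension orders:
        # one over the last dim, one over the leading dims.
        return x.sum(dim=-1), x.sum(dim=[0, 1])

    # Non-contiguous input, mirroring test_non_contiguous_input in this PR.
    x = torch.randn(1024, 32, 768, device="cuda").permute(1, 0, 2)
    out_last, out_leading = torch.compile(f)(x)

If one of these reductions is split, fusing it with the other becomes much harder; tracking the pre-split metadata is what this commit adds so the fusion stays possible.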
@@ -131,6 +131,20 @@ class MixOrderReductionTest(TestBase):
             metrics.codegen_mix_order_reduction,
         )
 
+    @inductor_config.patch(split_reductions=False)
+    def test_non_contiguous_input(self):
+        def f(x):
+            return x.sum(dim=-1), x.sum(dim=[0, 1])
+
+        x = torch.randn(1024, 32, 768, dtype=torch.float, device=GPU_TYPE).permute(
+            1, 0, 2
+        )
+        self.check_numeric(f, (x,))
+        self.assertEqual(
+            inductor_config.triton.mix_order_reduction,
+            metrics.codegen_mix_order_reduction,
+        )
+
     @inductor_config.patch(split_reductions=False)
     def test_multi_workspace_allocation(self):
         def f(x, y):
@@ -4547,14 +4547,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             )
             accumname2var[name] = self.cse.namedvar(name, dtype=torch.float)
         self.body.writeline("split_size = min(RSPLIT_SIZE, xnumel - xoffset)")
-        self.body.writeline("for suboff in range(0, split_size, XBLOCK):")
+        self.body.writeline("for _ in range(0, split_size, XBLOCK):")
         with self.body.indent(offset=1):
+            self.body.splice(self.indexing_code)
             self.body.writelines(
                 [
-                    "x0 = xindex + suboff",
+                    "xindex += XBLOCK",
+                    # TODO we force XBLOCK==1 for now so there is
+                    # no need to update the xmask
                 ]
             )
-            self.body.splice(self.indexing_code)
             self.body.splice(self.loads)
             self.body.splice(self.compute)
             self.body.splice(self.stores)
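For orientation, a minimal sketch of the loop skeleton those writeline/writelines/splice calls assemble into the kernel body. The x0 = xindex line is an assumption about what indexing_code emits (based on the codegen_iteration_ranges_entry change in the next hunk), and XBLOCK == 1 is assumed per the TODO above; this is a reconstruction, not output copied from this commit:

    split_size = min(RSPLIT_SIZE, xnumel - xoffset)
    for _ in range(0, split_size, XBLOCK):
        x0 = xindex       # emitted by indexing_code for the x dimension (assumed form)
        xindex += XBLOCK  # advance to the next row; with XBLOCK == 1 no xmask update is needed
        # ... loads / compute / stores for row x0 ...

Previously the loop variable suboff was used to compute x0 = xindex + suboff by hand; with this change the x index comes from the regular indexing_code path and xindex itself is advanced each iteration.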
@@ -5345,7 +5347,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
 
     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
         line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
-        if entry.root.is_loop:
+
+        # mix order reduction introduces an extra loop across the x
+        # dimension
+        if entry.root.is_loop or (self.mix_order_reduction and entry.prefix == "x"):
             self.indexing_code.writeline(line)
         else:
             # lift non-reduction stores outside loop
@@ -243,6 +243,7 @@ class MixOrderReduction:
     def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
         if not config.triton.mix_order_reduction:
             return False
+
         if not node1.is_gpu() or not node2.is_gpu():
             return False
         if node1.get_device().type != "cuda" or config.cuda_backend != "triton":  # type: ignore[union-attr]
@@ -4516,6 +4517,15 @@ class Scheduler:
         if node1 is node2:
             return False
 
+        # We don't further fuse with FusedMixOrderReductions for now.
+        # It's not a big deal since the score for fusion with
+        # mix order reduction is low. When we do this kind of fusion,
+        # the participants should have already been well fused.
+        if isinstance(node1, FusedMixOrderReductions) or isinstance(
+            node2, FusedMixOrderReductions
+        ):
+            return False
+
         why = WhyNoFuse(node1, node2)
 
         if node1.is_template() and self.get_backend(