write the body only once

This commit is contained in:
Blaine Burton Rister
2024-09-26 15:07:05 -07:00
parent e7356929a0
commit b0d5c04807

View File

@ -2523,24 +2523,23 @@ class TritonKernel(SIMDKernel):
innermost_tree = self.range_trees[-1]
if self.inside_reduction and innermost_tree.is_loop:
# Write the loop headers.
loop_trees = [tree for tree in self.range_trees if tree.is_loop]
for level, tree in enumerate(loop_trees):
# Write the loop header.
with self.body.indent(offset=level):
prefix = tree.prefix
self.body.writeline(
f"for {prefix}offset in range(0, {prefix}numel, {prefix.upper()}BLOCK):"
)
# Write the loop body.
with self.body.indent(offset=level + 1):
self.iteration_ranges_codegen_header(tree, self.body)
if tree == innermost_tree:
# The innermost loop performs the reduction.
self.body.splice(self.indexing_code)
self.body.splice(self.loads)
self.body.splice(self.compute)
self.body.splice(self.stores)
# The innermost loop performs the reduction.
with self.body.indent(offset=len(loop_trees)):
self.body.splice(self.indexing_code)
self.body.splice(self.loads)
self.body.splice(self.compute)
self.body.splice(self.stores)
# Write loop suffixes.
for level, tree in sorted(enumerate(loop_trees), reverse=True):