mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 09:04:32 +08:00
write the body only once
This commit is contained in:
@ -2523,24 +2523,23 @@ class TritonKernel(SIMDKernel):
|
||||
|
||||
innermost_tree = self.range_trees[-1]
|
||||
if self.inside_reduction and innermost_tree.is_loop:
|
||||
# Write the loop headers.
|
||||
loop_trees = [tree for tree in self.range_trees if tree.is_loop]
|
||||
for level, tree in enumerate(loop_trees):
|
||||
# Write the loop header.
|
||||
with self.body.indent(offset=level):
|
||||
prefix = tree.prefix
|
||||
self.body.writeline(
|
||||
f"for {prefix}offset in range(0, {prefix}numel, {prefix.upper()}BLOCK):"
|
||||
)
|
||||
|
||||
# Write the loop body.
|
||||
with self.body.indent(offset=level + 1):
|
||||
self.iteration_ranges_codegen_header(tree, self.body)
|
||||
if tree == innermost_tree:
|
||||
# The innermost loop performs the reduction.
|
||||
self.body.splice(self.indexing_code)
|
||||
self.body.splice(self.loads)
|
||||
self.body.splice(self.compute)
|
||||
self.body.splice(self.stores)
|
||||
|
||||
# The innermost loop performs the reduction.
|
||||
with self.body.indent(offset=len(loop_trees)):
|
||||
self.body.splice(self.indexing_code)
|
||||
self.body.splice(self.loads)
|
||||
self.body.splice(self.compute)
|
||||
self.body.splice(self.stores)
|
||||
|
||||
# Write loop suffixes.
|
||||
for level, tree in sorted(enumerate(loop_trees), reverse=True):
|
||||
|
||||
Reference in New Issue
Block a user