update combo kernel logic that looks for reduction trees

This commit is contained in:
Blaine Burton Rister
2024-10-14 17:52:40 -07:00
parent 3107b59ce4
commit 02365a50df

View File

@ -789,7 +789,7 @@ class ComboKernel(Kernel):
expr = V.graph.wrapper_code.generate_numel_expr(
name, tree, suffix=str(num)
)
if tree.prefix != "r":
if not tree.is_reduction:
assert isinstance(
grid[i][num], str
), f"Grid {grid[i][num]} should be a dynamic shape."
@ -799,7 +799,7 @@ class ComboKernel(Kernel):
), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
grid[i][num] = -expr if numel_sign == "-" else expr
if tree.prefix != "r" or sub_kernel.inside_reduction:
if not tree.is_reduction or sub_kernel.inside_reduction:
call_args.append(expr)
arg_types.append(type(expr))
@ -812,7 +812,7 @@ class ComboKernel(Kernel):
if numel_name not in self.dynamic_shape_args:
continue
expr = V.graph.sizevars.size_hint(tree.numel)
if tree.prefix != "r":
if not tree.is_reduction:
assert isinstance(
grid[i][num], str
), f"Grid {grid[i][num]} should be a dynamic shape."
@ -821,7 +821,7 @@ class ComboKernel(Kernel):
grid[i][num] == numel_sign + numel_name
), f"grid mismatch: {grid[i][num]} vs {numel_name}"
grid[i][num] = -expr if numel_sign == "-" else expr
if tree.prefix != "r" or sub_kernel.inside_reduction:
if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(expr)
def codegen_kernel(self, name: Optional[str] = None) -> str: