mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-12 14:54:55 +08:00
update combo kernel logic that looks for reduction trees
This commit is contained in:
@ -789,7 +789,7 @@ class ComboKernel(Kernel):
|
||||
expr = V.graph.wrapper_code.generate_numel_expr(
|
||||
name, tree, suffix=str(num)
|
||||
)
|
||||
if tree.prefix != "r":
|
||||
if not tree.is_reduction:
|
||||
assert isinstance(
|
||||
grid[i][num], str
|
||||
), f"Grid {grid[i][num]} should be a dynamic shape."
|
||||
@ -799,7 +799,7 @@ class ComboKernel(Kernel):
|
||||
), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
|
||||
grid[i][num] = -expr if numel_sign == "-" else expr
|
||||
|
||||
if tree.prefix != "r" or sub_kernel.inside_reduction:
|
||||
if not tree.is_reduction or sub_kernel.inside_reduction:
|
||||
call_args.append(expr)
|
||||
arg_types.append(type(expr))
|
||||
|
||||
@ -812,7 +812,7 @@ class ComboKernel(Kernel):
|
||||
if numel_name not in self.dynamic_shape_args:
|
||||
continue
|
||||
expr = V.graph.sizevars.size_hint(tree.numel)
|
||||
if tree.prefix != "r":
|
||||
if not tree.is_reduction:
|
||||
assert isinstance(
|
||||
grid[i][num], str
|
||||
), f"Grid {grid[i][num]} should be a dynamic shape."
|
||||
@ -821,7 +821,7 @@ class ComboKernel(Kernel):
|
||||
grid[i][num] == numel_sign + numel_name
|
||||
), f"grid mismatch: {grid[i][num]} vs {numel_name}"
|
||||
grid[i][num] = -expr if numel_sign == "-" else expr
|
||||
if tree.prefix != "r" or sub_kernel.inside_reduction:
|
||||
if not tree.is_reduction or sub_kernel.inside_reduction:
|
||||
extra_args.append(expr)
|
||||
|
||||
def codegen_kernel(self, name: Optional[str] = None) -> str:
|
||||
|
||||
Reference in New Issue
Block a user