[Tiling rewrite pt1] Normalize reads and writes to common iter space (#153723)

To pick the globally best tiling, we need to normalize every node's reads and writes to a common iteration space. This first PR finds a common split among the nodes in a fused scheduler node, and then normalizes their reads and writes to that common split.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153723
Approved by: https://github.com/jansel
This commit is contained in:
eellison
2025-06-02 16:44:56 -07:00
committed by PyTorch MergeBot
parent 635b73e697
commit 00dfd3891e
8 changed files with 659 additions and 13 deletions

View File

@ -705,6 +705,26 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
return new_ranges, return_getters_groups
@classmethod
def prepare_split_iteration_lengths(
    cls,
    groups: Iterable[sympy.Expr],
    lengths: Sequence[Sequence[sympy.Expr]],
    reduction_numel: sympy.Expr = sympy.S.One,
) -> Sequence[Sequence[sympy.Expr]]:
    """Fill in the reduction numel of ``lengths`` if it is missing.

    ``lengths`` is indexed as a pair: ``lengths[0]`` and ``lengths[1]``.
    The second entry is replaced by ``[reduction_numel]`` only when all of
    the following hold:

    * ``lengths[1]`` is empty,
    * ``reduction_numel`` is not statically known to equal 1, and
    * the product of ``groups`` is statically known to equal the product
      of ``lengths[0]`` multiplied by ``reduction_numel``.

    Args:
        groups: Expressions whose product is compared against the lengths.
        lengths: Pair of length sequences; ``lengths[1]`` may be empty.
        reduction_numel: Value to insert as the sole entry of the second
            sequence when it is missing.

    Returns:
        ``(lengths[0], [reduction_numel])`` when the conditions above are
        met, otherwise ``lengths`` itself, unchanged.
    """
    shape_env = V.graph.sizevars

    # Nothing to fill in: the second entry is already populated.
    if len(lengths[1]) != 0:
        return lengths

    # A reduction numel statically known to be 1 is never filled in.
    if shape_env.statically_known_equals(reduction_numel, sympy.S.One):
        return lengths

    # Only fill in when the element counts are statically known to line up.
    group_numel = sympy_product(groups)
    filled_numel = sympy_product(lengths[0]) * reduction_numel
    if not shape_env.statically_known_equals(group_numel, filled_numel):
        return lengths

    return (lengths[0], [reduction_numel])
@classmethod
def is_compatible(
cls,
@ -712,15 +732,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
lengths: Sequence[Sequence[sympy.Expr]],
reduction_numel: sympy.Expr = sympy.S.One,
) -> bool:
# Fill in the reduction numel, in case the node is missing it.
sizevars = V.graph.sizevars
if len(lengths[1]) == 0 and (
sizevars.statically_known_equals(
sympy_product(groups),
sympy_product(lengths[0]) * reduction_numel,
)
):
lengths = (lengths[0], [reduction_numel])
lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel)
try:
cls._split_iteration_ranges(groups, lengths)