mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Tiling rewrite pt1] Normalize reads and writes to common iter space (#153723)
In order to take the globally best tiling, we need to normalize all the node read and writes to a common iteration space. This first pr finds a common split among nodes in a fused scheduler node, and then normalizes reads and writes to the common split. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153723 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
635b73e697
commit
00dfd3891e
@ -705,6 +705,26 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
|
||||
return new_ranges, return_getters_groups
|
||||
|
||||
@classmethod
|
||||
def prepare_split_iteration_lengths(
|
||||
cls,
|
||||
groups: Iterable[sympy.Expr],
|
||||
lengths: Sequence[Sequence[sympy.Expr]],
|
||||
reduction_numel: sympy.Expr = sympy.S.One,
|
||||
) -> Sequence[Sequence[sympy.Expr]]:
|
||||
"Fill in the reduction numel of lengths if missing"
|
||||
sizevars = V.graph.sizevars
|
||||
if len(lengths[1]) == 0 and (
|
||||
not sizevars.statically_known_equals(reduction_numel, sympy.S.One)
|
||||
and sizevars.statically_known_equals(
|
||||
sympy_product(groups),
|
||||
sympy_product(lengths[0]) * reduction_numel,
|
||||
)
|
||||
):
|
||||
return (lengths[0], [reduction_numel])
|
||||
|
||||
return lengths
|
||||
|
||||
@classmethod
|
||||
def is_compatible(
|
||||
cls,
|
||||
@ -712,15 +732,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
lengths: Sequence[Sequence[sympy.Expr]],
|
||||
reduction_numel: sympy.Expr = sympy.S.One,
|
||||
) -> bool:
|
||||
# Fill in the reduction numel, in case the node is missing it.
|
||||
sizevars = V.graph.sizevars
|
||||
if len(lengths[1]) == 0 and (
|
||||
sizevars.statically_known_equals(
|
||||
sympy_product(groups),
|
||||
sympy_product(lengths[0]) * reduction_numel,
|
||||
)
|
||||
):
|
||||
lengths = (lengths[0], [reduction_numel])
|
||||
lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel)
|
||||
|
||||
try:
|
||||
cls._split_iteration_ranges(groups, lengths)
|
||||
|
Reference in New Issue
Block a user