Revert "[inductor] Refactor reduction type choices into V.choices (#139585)"

This reverts commit 6438c8637a7e28b676a1ccfe942dc37375d0cb14.

Reverted https://github.com/pytorch/pytorch/pull/139585 on behalf of https://github.com/kit1980 due to breaking internal builds, see D65800124 ([comment](https://github.com/pytorch/pytorch/pull/139585#issuecomment-2471392822))
Author: PyTorch MergeBot
Date: 2024-11-12 19:32:14 +00:00
commit 069a71023b (parent c0ddd10f6d)
5 changed files with 100 additions and 179 deletions

torch/_inductor/choices.py

@@ -1,88 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from . import config
from .runtime.hints import ReductionHint
from .virtualized import V
if TYPE_CHECKING:
from .codegen.simd_kernel_features import SIMDKernelFeatures
class InductorChoices:
"""
    This class contains a collection of default heuristics that affect the performance of our generated
    code. We try not to put correctness requirements in this file.
You can override the choices made here by doing:
class MyHeuristics(InductorChoices):
...
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
"""
@staticmethod
def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool:
"""Heuristic to decide if a cooperative reduction should be used."""
if config.triton.force_cooperative_reductions:
return True
if (
not config.triton.cooperative_reductions
or V.graph.get_current_device_or_throw().type == "cpu"
):
return False
xhint = V.graph.sizevars.size_hint(features.numel, fallback=2)
if xhint <= 8:
threshold = 32768 * xhint
elif xhint <= 16:
threshold = 2097152
else:
return False
# TODO(jansel): should this default on for dynamic shapes?
return V.graph.sizevars.statically_known_geq(
features.reduction_numel, threshold
)
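        # Illustrative example (not in the original source): with an xnumel size
        # hint of 4, the threshold is 32768 * 4 = 131072, so a cooperative
        # reduction is chosen only if reduction_numel is statically known to be
        # at least 131072.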
@staticmethod
def should_use_persistent_reduction(
features: SIMDKernelFeatures, cooperative_reduction: bool
) -> bool:
"""
Heuristic to decide if a persistent reduction should be used.
"""
if not config.triton.persistent_reductions:
return False
threshold = {
ReductionHint.INNER: 1024,
}.get(features.get_reduction_hint(), 64)
if cooperative_reduction:
# The RSPLIT of cooperative reductions means each thread block is operating on fewer elements
try:
threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32)
except ValueError:
pass # unbacked symint
# If multi_kernel is enabled, we do more aggressive persistent reduction.
# This may result in some persistent reductions slower than the
# corresponding non-persistent reductions. MultiKernel will do benchmarking
# to pick the faster one.
if config.triton.multi_kernel:
threshold *= 16
return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold) # type: ignore[arg-types]
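        # Illustrative example (not in the original source): with an INNER
        # reduction hint the base threshold is 1024; enabling multi_kernel scales
        # it to 1024 * 16 = 16384, so the reduction is made persistent when
        # reduction_numel is statically known to be <= 16384.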
@staticmethod
def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
"""
Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
Strangely this is faster than a [1, RBLOCK] block in some cases.
"""
return (
features.get_reduction_hint() == ReductionHint.INNER
and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
)
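
Note: the refactor being reverted here let downstream code override these heuristics out of tree through the virtualized choices handler, as the docstring above describes. A minimal sketch of that usage (only valid with the reverted change applied; the MyHeuristics class below is purely illustrative):

    from torch._inductor.choices import InductorChoices
    from torch._inductor.virtualized import V


    class MyHeuristics(InductorChoices):
        # Illustrative override: never opt in to cooperative reductions.
        @staticmethod
        def should_use_cooperative_reduction(features) -> bool:
            return False


    V.set_choices_handler(MyHeuristics())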

torch/_inductor/codegen/simd.py

@@ -20,6 +20,7 @@ from typing import (
Optional,
Sequence,
Tuple,
Type,
Union,
)
@@ -619,20 +620,7 @@ class SIMDKernel(Kernel):
except CantSplit:
return False
def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]):
groups = [rt.numel for rt in self.range_trees]
if not self.inside_reduction:
groups[-1] = sympy.S.One
return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges)
@classmethod
def map_kernel_groups_to_node_sizes(
cls,
groups: Sequence[sympy.Expr],
lengths: Sequence[Sequence[sympy.Expr]],
set_ranges,
) -> List[List[sympy.Expr]]:
def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
"""
We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).
@@ -645,14 +633,20 @@ class SIMDKernel(Kernel):
This function matches and resplits lengths to the groups of
this kernel to enable tiled + non-tiled fusions.
"""
if len(lengths) == len(groups) and all(
groups = [rt.numel for rt in self.range_trees]
if not self.inside_reduction:
groups[-1] = sympy.S.One
if len(lengths) == len(self.range_trees) and all(
V.graph.sizevars.simplify(sympy_product(x) - g) == 0
for x, g in zip(lengths, groups)
):
return set_ranges(*lengths)
return self.set_ranges(*lengths)
new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths)
itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))]
new_ranges, return_getters_groups = self._split_iteration_ranges(
groups, lengths
)
itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges)))
return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
def is_indirect_indexing(self, index: sympy.Expr):
@@ -1230,12 +1224,40 @@ class SIMDScheduling(BaseScheduling):
return True
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel
node_schedule = kernel_features.node_schedule
tiled_groups = self.select_tiling(
node_schedule, kernel_features.numel, kernel_features.reduction_numel
)
        kernels = self.create_kernel_choices(
            kernel_features, tiled_groups, {"features": kernel_features}
        )
        is_scan = kernel_features.contains_op("scan")
        is_split_scan = is_scan and any(
            node.is_split_scan() for node in kernel_features.scheduler_nodes()
        )
kernel_type: Type[SIMDKernel] = self.kernel_type
if is_split_scan and issubclass(TritonSplitScanKernel, kernel_type):
kernel_type = TritonSplitScanKernel
kernel_args = tiled_groups
kernel_kwargs: Dict[str, Any] = {"features": kernel_features}
if is_scan:
# TODO(jansel): scan does not yet work with cooperative reductions
kernel_kwargs["override_cooperative_reduction"] = False
# ops.sort only works with persistent reduction, and is not bandwidth bound anyway
# so taking the hit of non-coalesced loads is okay
if kernel_features.contains_op("sort"):
kernel_kwargs["override_persistent_reduction"] = True
kernel = kernel_type(
*kernel_args,
**kernel_kwargs,
)
kernels = self.add_multi_kernel_choices(
kernel, kernel_args, kernel_kwargs, node_schedule
)
for kernel in kernels:
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
@@ -1291,15 +1313,10 @@ class SIMDScheduling(BaseScheduling):
self.scheduler.free_buffers()
def create_kernel_choices(
self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs
def add_multi_kernel_choices(
self, kernel, kernel_args, kernel_kwargs, node_schedule
) -> List[SIMDKernel]:
return [
self.kernel_type(
*kernel_args,
**kernel_kwargs,
)
]
return [kernel]
def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
with kernel:

torch/_inductor/codegen/simd_kernel_features.py

@@ -7,11 +7,11 @@ from typing import Any, Dict, Iterable, List, Type, Union
import sympy
import torch
from torch._inductor.scheduler import SchedulerNode
from ...utils._ordered_set import OrderedSet
from ..dependencies import Dep, MemoryDep
from ..runtime.hints import ReductionHint
from ..scheduler import SchedulerNode
from ..utils import cache_on_self
from ..virtualized import V
@@ -73,13 +73,8 @@ class SIMDKernelFeatures:
reduction_numel: sympy.Expr = sympy.S.One,
):
self.node_schedule = node_schedule
# numel excludes reduction_numel
self.numel: sympy.Expr = V.graph.sizevars.simplify(numel)
self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel)
@cache_on_self
def is_reduction(self) -> bool:
return self.reduction_numel != 1
self.numel = V.graph.sizevars.simplify(numel) # numel excludes reduction_numel
self.reduction_numel = V.graph.sizevars.simplify(reduction_numel)
@cache_on_self
def scheduler_nodes(self) -> Iterable[SchedulerNode]:

torch/_inductor/codegen/triton.py

@@ -50,7 +50,7 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from ..codecache import code_hash, get_path, PyCodeCache
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import TRITON_MAX_BLOCK
from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
@@ -1396,9 +1396,28 @@ class TritonKernel(SIMDKernel):
return triton_type(dtype)
def should_use_cooperative_reduction(self) -> bool:
return self.inside_reduction and V.choices.should_use_cooperative_reduction(
self.features
)
"""Heuristic to decide self.cooperative_reduction should be used."""
if not self.inside_reduction:
return False
if config.triton.force_cooperative_reductions:
return True
if (
not config.triton.cooperative_reductions
or V.graph.get_current_device_or_throw().type == "cpu"
):
return False
xnumel, rnumel = self.numels
# TODO(jansel): base this on num_bytes_read rather than numel
xhint = V.graph.sizevars.size_hint(xnumel, fallback=2)
if xhint <= 8:
threshold = 32768 * xhint
elif xhint <= 16:
threshold = 2097152
else:
return False
# TODO(jansel): should this default on for dynamic shapes?
return V.graph.sizevars.statically_known_geq(rnumel, threshold)
def init_cooperative_reduction(self):
"""One time setup code for cooperative reductions."""
@@ -1452,15 +1471,39 @@ class TritonKernel(SIMDKernel):
return True
def should_use_persistent_reduction(self) -> bool:
return self.inside_reduction and V.choices.should_use_persistent_reduction(
self.features, self.cooperative_reduction
)
"""
Heuristic to set self.persistent_reduction and add guards
if needed.
"""
if not (self.inside_reduction and config.triton.persistent_reductions):
return False
threshold = {
ReductionHint.INNER: 1024,
}.get(self.features.get_reduction_hint(), 64)
if self.cooperative_reduction:
# The RSPLIT of cooperative reductions means each thread block is operating on fewer elements
xnumel, _ = self.numels
try:
threshold *= 32 // V.graph.sizevars.size_hint(xnumel)
except ValueError:
pass # unbacked symint
# If multi_kernel is enabled, we do more aggressive persistent reduction.
# This may result in some persistent reductions slower than the
# corresponding non-persistent reductions. MultiKernel will do benchmarking
# to pick the faster one.
if config.triton.multi_kernel:
threshold *= 16
last_numel = self.numels[-1]
return V.graph.sizevars.statically_known_leq(last_numel, threshold) # type: ignore[arg-types]
def want_no_x_dim(self):
return (
self.persistent_reduction
and len(self.numels) == 2
and V.choices.want_no_x_dim(self.features)
and self.features.get_reduction_hint() == ReductionHint.INNER
and V.graph.sizevars.statically_known_geq(self.numels[-1], 256) # type: ignore[arg-types]
)
@property
@@ -3555,36 +3598,12 @@ class TritonScheduling(SIMDScheduling):
store_cache()
return ms, mod.__file__
def create_kernel_choices(
self, kernel_features, kernel_args, kernel_kwargs
) -> List[SIMDKernel]:
is_scan = kernel_features.contains_op("scan")
is_split_scan = is_scan and any(
node.is_split_scan() for node in kernel_features.scheduler_nodes()
)
kernel_type = TritonKernel
if is_split_scan:
from .triton_split_scan import TritonSplitScanKernel
kernel_type = TritonSplitScanKernel
if is_scan:
# TODO(jansel): scan does not yet work with cooperative reductions
kernel_kwargs["override_cooperative_reduction"] = False
# ops.sort only works with persistent reduction, and is not bandwidth bound anyway
# so taking the hit of non-coalesced loads is okay
if kernel_features.contains_op("sort"):
kernel_kwargs["override_persistent_reduction"] = True
kernel = kernel_type(*kernel_args, **kernel_kwargs)
return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs)
def add_multi_kernel_choices(
self,
kernel: SIMDKernel,
kernel_args: List[Any],
kernel_kwargs: Dict[str, Any],
node_schedule: List[BaseSchedulerNode],
) -> List[SIMDKernel]:
kernels: List[SIMDKernel] = [kernel]
if not config.triton.multi_kernel:

torch/_inductor/virtualized.py

@@ -73,7 +73,6 @@ from .ops_handler import ( # noqa: F401
if TYPE_CHECKING:
import torch
from torch._inductor.choices import InductorChoices
from torch._inductor.codegen.cpp_utils import LocalBufferContext
from torch._inductor.debug import DebugContext
from torch._inductor.graph import GraphLowering
@@ -168,22 +167,6 @@ _local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
)
def _choices_default():
"""
Lazy init the global choices handler
We virtualize InductorChoices to allow changing inductor heuristics from out of tree.
"""
from torch._inductor.choices import InductorChoices
rv = InductorChoices()
setattr(threadlocal, _choices._key, rv)
return rv
_choices: Virtualized[InductorChoices] = Virtualized("choices", _choices_default)
class OpsValue:
"""The return type of most ops calls.
@@ -328,7 +311,6 @@ class _V:
get_current_node: Callable[[], Any] = _current_node._get_handler
set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
set_choices_handler: Callable[[Any], Any] = _choices._set_handler
@property
def ops(self) -> OpsHandler[Any]:
@@ -375,9 +357,5 @@ class _V:
def local_buffer_context(self):
return _local_buffer_context._get_handler()
@property
def choices(self) -> InductorChoices:
return _choices._get_handler()
V = _V()