Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Revert "[inductor] Refactor reduction type choices into V.choices (#139585)"
This reverts commit 6438c8637a7e28b676a1ccfe942dc37375d0cb14. Reverted https://github.com/pytorch/pytorch/pull/139585 on behalf of https://github.com/kit1980 due to breaking internal builds, see D65800124 ([comment](https://github.com/pytorch/pytorch/pull/139585#issuecomment-2471392822))
@@ -1,88 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from . import config
from .runtime.hints import ReductionHint
from .virtualized import V


if TYPE_CHECKING:
    from .codegen.simd_kernel_features import SIMDKernelFeatures


class InductorChoices:
    """
    This class contains a collection of default heuristics that effect performance of our generated
    code. We try to not put correctness requirements in this file.

    You can override the choices made here by doing:

        class MyHeuristics(InductorChoices):
            ...

        torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
    """

    @staticmethod
    def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool:
        """Heuristic to decide if a cooperative reduction should be used."""
        if config.triton.force_cooperative_reductions:
            return True
        if (
            not config.triton.cooperative_reductions
            or V.graph.get_current_device_or_throw().type == "cpu"
        ):
            return False

        xhint = V.graph.sizevars.size_hint(features.numel, fallback=2)
        if xhint <= 8:
            threshold = 32768 * xhint
        elif xhint <= 16:
            threshold = 2097152
        else:
            return False
        # TODO(jansel): should this default on for dynamic shapes?
        return V.graph.sizevars.statically_known_geq(
            features.reduction_numel, threshold
        )

    @staticmethod
    def should_use_persistent_reduction(
        features: SIMDKernelFeatures, cooperative_reduction: bool
    ) -> bool:
        """
        Heuristic to decide if a persistent reduction should be used.
        """
        if not config.triton.persistent_reductions:
            return False
        threshold = {
            ReductionHint.INNER: 1024,
        }.get(features.get_reduction_hint(), 64)

        if cooperative_reduction:
            # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements
            try:
                threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32)
            except ValueError:
                pass  # unbacked symint

        # If multi_kernel is enabled, we do more aggressive persistent reduction.
        # This may result in some persistent reductions slower than the
        # corresponding non-persistent reductions. MultiKernel will do benchmarking
        # to pick the faster one.
        if config.triton.multi_kernel:
            threshold *= 16
        return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold)  # type: ignore[arg-types]

    @staticmethod
    def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
        """
        Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
        So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
        Strangely this is faster than a [1, RBLOCK] block in some cases.
        """
        return (
            features.get_reduction_hint() == ReductionHint.INNER
            and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
        )
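The file deleted above was the centerpiece of #139585: its docstring describes an out-of-tree override hook for Inductor's performance heuristics. A minimal sketch of how that hook was meant to be used, based only on the docstring above (illustrative; after this revert the torch._inductor.choices module and V.set_choices_handler no longer exist):

    from torch._inductor.choices import InductorChoices
    from torch._inductor.virtualized import V


    class MyHeuristics(InductorChoices):
        # Example policy: never emit cooperative reductions, regardless of sizes.
        @staticmethod
        def should_use_cooperative_reduction(features) -> bool:
            return False


    V.set_choices_handler(MyHeuristics())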
@@ -20,6 +20,7 @@ from typing import (
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

@@ -619,20 +620,7 @@ class SIMDKernel(Kernel):
        except CantSplit:
            return False

    def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]):
        groups = [rt.numel for rt in self.range_trees]
        if not self.inside_reduction:
            groups[-1] = sympy.S.One

        return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges)

    @classmethod
    def map_kernel_groups_to_node_sizes(
        cls,
        groups: Sequence[sympy.Expr],
        lengths: Sequence[Sequence[sympy.Expr]],
        set_ranges,
    ) -> List[List[sympy.Expr]]:
    def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
        """
        We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).

@@ -645,14 +633,20 @@ class SIMDKernel(Kernel):
        This function matches and resplits lengths to the groups of
        this kernel to enable tiled + non-tiled fusions.
        """
        if len(lengths) == len(groups) and all(
        groups = [rt.numel for rt in self.range_trees]
        if not self.inside_reduction:
            groups[-1] = sympy.S.One

        if len(lengths) == len(self.range_trees) and all(
            V.graph.sizevars.simplify(sympy_product(x) - g) == 0
            for x, g in zip(lengths, groups)
        ):
            return set_ranges(*lengths)
            return self.set_ranges(*lengths)

        new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths)
        itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))]
        new_ranges, return_getters_groups = self._split_iteration_ranges(
            groups, lengths
        )
        itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges)))
        return [[fn(itervars) for fn in fns] for fns in return_getters_groups]

    def is_indirect_indexing(self, index: sympy.Expr):
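The docstring above (truncated at the hunk boundary) describes index resplitting: a node written over one flattened range `for i0 in s0*s1` is remapped onto a kernel tiled as (s0, s1) by recovering the flat index from the tiled iteration variables. A toy sympy illustration of that remapping (variable names are ours, not the function's output):

    import sympy

    s0, s1, i1, i2 = sympy.symbols("s0 s1 i1 i2", nonnegative=True, integer=True)

    # tiled loops `for i1 in s0: for i2 in s1:` reconstruct the original flat index
    i0 = i1 * s1 + i2

    # sanity check: the largest tiled index equals the largest flat index s0*s1 - 1
    assert sympy.simplify((s0 - 1) * s1 + (s1 - 1) - (s0 * s1 - 1)) == 0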
@@ -1230,12 +1224,40 @@ class SIMDScheduling(BaseScheduling):
        return True

    def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
        from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel

        node_schedule = kernel_features.node_schedule
        tiled_groups = self.select_tiling(
            node_schedule, kernel_features.numel, kernel_features.reduction_numel
        )
        kernels = self.create_kernel_choices(
            kernel_features, tiled_groups, {"features": kernel_features}

        is_scan = kernel_features.contains_op("scan")
        is_split_scan = is_scan and any(
            node.is_split_scan() for node in kernel_features.scheduler_nodes()
        )
        kernel_type: Type[SIMDKernel] = self.kernel_type
        if is_split_scan and issubclass(TritonSplitScanKernel, kernel_type):
            kernel_type = TritonSplitScanKernel

        kernel_args = tiled_groups
        kernel_kwargs: Dict[str, Any] = {"features": kernel_features}

        if is_scan:
            # TODO(jansel): scan does not yet work with cooperative reductions
            kernel_kwargs["override_cooperative_reduction"] = False

        # ops.sort only works with persistent reduction, and is not bandwidth bound anyway
        # so taking the hit of non-coalesced loads is okay
        if kernel_features.contains_op("sort"):
            kernel_kwargs["override_persistent_reduction"] = True

        kernel = kernel_type(
            *kernel_args,
            **kernel_kwargs,
        )

        kernels = self.add_multi_kernel_choices(
            kernel, kernel_args, kernel_kwargs, node_schedule
        )
        for kernel in kernels:
            self.codegen_node_schedule_with_kernel(node_schedule, kernel)
@@ -1291,15 +1313,10 @@ class SIMDScheduling(BaseScheduling):

        self.scheduler.free_buffers()

    def create_kernel_choices(
        self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs
    def add_multi_kernel_choices(
        self, kernel, kernel_args, kernel_kwargs, node_schedule
    ) -> List[SIMDKernel]:
        return [
            self.kernel_type(
                *kernel_args,
                **kernel_kwargs,
            )
        ]
        return [kernel]

    def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
        with kernel:
@@ -7,11 +7,11 @@ from typing import Any, Dict, Iterable, List, Type, Union
import sympy

import torch
from torch._inductor.scheduler import SchedulerNode

from ...utils._ordered_set import OrderedSet
from ..dependencies import Dep, MemoryDep
from ..runtime.hints import ReductionHint
from ..scheduler import SchedulerNode
from ..utils import cache_on_self
from ..virtualized import V

@@ -73,13 +73,8 @@ class SIMDKernelFeatures:
        reduction_numel: sympy.Expr = sympy.S.One,
    ):
        self.node_schedule = node_schedule
        # numel excludes reduction_numel
        self.numel: sympy.Expr = V.graph.sizevars.simplify(numel)
        self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel)

    @cache_on_self
    def is_reduction(self) -> bool:
        return self.reduction_numel != 1
        self.numel = V.graph.sizevars.simplify(numel)  # numel excludes reduction_numel
        self.reduction_numel = V.graph.sizevars.simplify(reduction_numel)

    @cache_on_self
    def scheduler_nodes(self) -> Iterable[SchedulerNode]:
@@ -50,7 +50,7 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from ..codecache import code_hash, get_path, PyCodeCache
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import TRITON_MAX_BLOCK
from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
@@ -1396,9 +1396,28 @@ class TritonKernel(SIMDKernel):
        return triton_type(dtype)

    def should_use_cooperative_reduction(self) -> bool:
        return self.inside_reduction and V.choices.should_use_cooperative_reduction(
            self.features
        )
        """Heuristic to decide self.cooperative_reduction should be used."""
        if not self.inside_reduction:
            return False
        if config.triton.force_cooperative_reductions:
            return True
        if (
            not config.triton.cooperative_reductions
            or V.graph.get_current_device_or_throw().type == "cpu"
        ):
            return False

        xnumel, rnumel = self.numels
        # TODO(jansel): base this on num_bytes_read rather than numel
        xhint = V.graph.sizevars.size_hint(xnumel, fallback=2)
        if xhint <= 8:
            threshold = 32768 * xhint
        elif xhint <= 16:
            threshold = 2097152
        else:
            return False
        # TODO(jansel): should this default on for dynamic shapes?
        return V.graph.sizevars.statically_known_geq(rnumel, threshold)

    def init_cooperative_reduction(self):
        """One time setup code for cooperative reductions."""
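Restated outside the kernel class, the size schedule in the restored heuristic above is a simple piecewise threshold on the x-size hint; a cooperative reduction is only considered when the reduction dimension is statically at least that large. A standalone sketch (the helper name and standalone form are ours, not Inductor API):

    def cooperative_reduction_threshold(xhint: int) -> int | None:
        # Mirrors the branches above: few x elements tolerate smaller reductions,
        # and beyond 16 x elements cooperative reductions are never chosen.
        if xhint <= 8:
            return 32768 * xhint
        if xhint <= 16:
            return 2097152
        return None


    # Examples: xhint=2 -> 65536, xhint=8 -> 262144, xhint=16 -> 2097152, xhint=64 -> None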
@@ -1452,15 +1471,39 @@ class TritonKernel(SIMDKernel):
        return True

    def should_use_persistent_reduction(self) -> bool:
        return self.inside_reduction and V.choices.should_use_persistent_reduction(
            self.features, self.cooperative_reduction
        )
        """
        Heuristic to set self.persistent_reduction and add guards
        if needed.
        """
        if not (self.inside_reduction and config.triton.persistent_reductions):
            return False
        threshold = {
            ReductionHint.INNER: 1024,
        }.get(self.features.get_reduction_hint(), 64)

        if self.cooperative_reduction:
            # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements
            xnumel, _ = self.numels
            try:
                threshold *= 32 // V.graph.sizevars.size_hint(xnumel)
            except ValueError:
                pass  # unbacked symint

        # If multi_kernel is enabled, we do more aggressive persistent reduction.
        # This may result in some persistent reductions slower than the
        # corresponding non-persistent reductions. MultiKernel will do benchmarking
        # to pick the faster one.
        if config.triton.multi_kernel:
            threshold *= 16
        last_numel = self.numels[-1]
        return V.graph.sizevars.statically_known_leq(last_numel, threshold)  # type: ignore[arg-types]

    def want_no_x_dim(self):
        return (
            self.persistent_reduction
            and len(self.numels) == 2
            and V.choices.want_no_x_dim(self.features)
            and self.features.get_reduction_hint() == ReductionHint.INNER
            and V.graph.sizevars.statically_known_geq(self.numels[-1], 256)  # type: ignore[arg-types]
        )

    @property
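The persistent-reduction check above follows the same shape: a base threshold of 1024 for ReductionHint.INNER and 64 otherwise, shrunk when a cooperative reduction has already split the work across thread blocks, and inflated 16x when multi_kernel benchmarking can fall back to the non-persistent variant. A rough standalone restatement (names are ours; the deleted choices.py variant additionally clamps the divisor with min(..., 32), which is what this sketch follows):

    def persistent_reduction_threshold(
        inner_hint: bool, xhint: int, cooperative: bool, multi_kernel: bool
    ) -> int:
        threshold = 1024 if inner_hint else 64
        if cooperative:
            # RSPLIT means each thread block sees only a slice of the reduction
            threshold *= 32 // min(xhint, 32)
        if multi_kernel:
            # MultiKernel benchmarks persistent vs. non-persistent and keeps the faster
            threshold *= 16
        return threshold


    # Example: INNER hint, xhint=8, cooperative, no multi_kernel -> 1024 * (32 // 8) = 4096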
@@ -3555,36 +3598,12 @@ class TritonScheduling(SIMDScheduling):
        store_cache()
        return ms, mod.__file__

    def create_kernel_choices(
        self, kernel_features, kernel_args, kernel_kwargs
    ) -> List[SIMDKernel]:
        is_scan = kernel_features.contains_op("scan")
        is_split_scan = is_scan and any(
            node.is_split_scan() for node in kernel_features.scheduler_nodes()
        )
        kernel_type = TritonKernel
        if is_split_scan:
            from .triton_split_scan import TritonSplitScanKernel

            kernel_type = TritonSplitScanKernel

        if is_scan:
            # TODO(jansel): scan does not yet work with cooperative reductions
            kernel_kwargs["override_cooperative_reduction"] = False

        # ops.sort only works with persistent reduction, and is not bandwidth bound anyway
        # so taking the hit of non-coalesced loads is okay
        if kernel_features.contains_op("sort"):
            kernel_kwargs["override_persistent_reduction"] = True

        kernel = kernel_type(*kernel_args, **kernel_kwargs)
        return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs)

    def add_multi_kernel_choices(
        self,
        kernel: SIMDKernel,
        kernel_args: List[Any],
        kernel_kwargs: Dict[str, Any],
        node_schedule: List[BaseSchedulerNode],
    ) -> List[SIMDKernel]:
        kernels: List[SIMDKernel] = [kernel]
        if not config.triton.multi_kernel:
@@ -73,7 +73,6 @@ from .ops_handler import (  # noqa: F401

if TYPE_CHECKING:
    import torch
    from torch._inductor.choices import InductorChoices
    from torch._inductor.codegen.cpp_utils import LocalBufferContext
    from torch._inductor.debug import DebugContext
    from torch._inductor.graph import GraphLowering
@@ -168,22 +167,6 @@ _local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
)


def _choices_default():
    """
    Lazy init the global choices handler

    We virtualize InductorChoices to allow changing inductor heuristics from out of tree.
    """
    from torch._inductor.choices import InductorChoices

    rv = InductorChoices()
    setattr(threadlocal, _choices._key, rv)
    return rv


_choices: Virtualized[InductorChoices] = Virtualized("choices", _choices_default)


class OpsValue:
    """The return type of most ops calls.

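The deleted _choices_default is the lazy initializer behind V.choices: the first access on a thread builds a default InductorChoices, stores it in thread-local storage, and later accesses return either that default or whatever V.set_choices_handler installed. A generic illustration of the pattern, not the actual Virtualized implementation (all names below are local to this sketch):

    import threading

    _threadlocal = threading.local()


    class _DefaultChoices:
        """Stand-in for torch._inductor.choices.InductorChoices in this sketch."""


    def get_choices(key: str = "choices") -> _DefaultChoices:
        handler = getattr(_threadlocal, key, None)
        if handler is None:
            handler = _DefaultChoices()          # lazily build the default handler
            setattr(_threadlocal, key, handler)  # cache it for later lookups
        return handler


    def set_choices_handler(handler, key: str = "choices") -> None:
        setattr(_threadlocal, key, handler)      # out-of-tree override point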
@@ -328,7 +311,6 @@ class _V:
    get_current_node: Callable[[], Any] = _current_node._get_handler
    set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
    get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
    set_choices_handler: Callable[[Any], Any] = _choices._set_handler

    @property
    def ops(self) -> OpsHandler[Any]:
@@ -375,9 +357,5 @@ class _V:
    def local_buffer_context(self):
        return _local_buffer_context._get_handler()

    @property
    def choices(self) -> InductorChoices:
        return _choices._get_handler()


V = _V()