[inductor] Refactor runtime files into torch._inductor.runtime (part 2) (#124553)

I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.
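For orientation, a minimal sketch (not part of the diff) of the import surface this refactor consolidates, using only names that appear in the hunks below:

# Hint enums and the instance_descriptor helper now live in
# torch._inductor.runtime.hints, whose shown imports are only collections,
# dataclasses, enum, and (optionally) triton:
from torch._inductor.runtime.hints import (
    AutotuneHint,
    HeuristicType,
    ReductionHint,
    TileHint,
    instance_descriptor,
)

# Triton helpers and heuristics stay under the runtime package:
from torch._inductor.runtime import triton_helpers, triton_heuristics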

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124553
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552
Author: Jason Ansel
Date: 2024-04-21 11:09:44 -07:00
Committed by: PyTorch MergeBot
Parent: 480585fd2b
Commit: bb8815bc31
9 changed files with 103 additions and 116 deletions

View File

@@ -381,12 +381,8 @@ class CudaReproTests(TestCase):
https://github.com/pytorch/torchdynamo/issues/1670
"""
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
from torch._inductor.runtime.triton_heuristics import (
CachingAutotuner,
grid,
HeuristicType,
)
from torch._inductor.utils import instance_descriptor
from torch._inductor.runtime.hints import HeuristicType, instance_descriptor
from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid
def autotune(configs, meta):
def decorator(fn):

View File

@@ -34,7 +34,7 @@ import torch.utils._pytree as pytree
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
from torch._inductor.runtime.triton_heuristics import AutotuneHint
from torch._inductor.runtime.hints import AutotuneHint
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import ValueRanges
@@ -44,8 +44,9 @@ from ..._dynamo.utils import counters
from .. import config, ir, scheduler
from ..codecache import code_hash, get_path, PyCodeCache
from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
from ..ir import IRNode, TritonTemplateBuffer
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..runtime.hints import ReductionHint
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
cache_on_self,
@@ -120,15 +121,9 @@ def gen_common_triton_imports():
imports.splice(
"""
from torch._inductor.runtime import (
triton_helpers,
triton_heuristics,
libdevice,
tl_math,
AutotuneHint,
)
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.utils import instance_descriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
"""
)
return imports.getvalue()

View File

@@ -2,7 +2,8 @@ import functools
from typing import Optional, Set
from torch._inductor import config, ir
import torch._inductor.runtime.hints
from torch._inductor import config
from torch._inductor.codegen.triton import (
IterationRangesRoot,
@@ -36,7 +37,7 @@ class TritonSplitScanKernel(TritonKernel):
*groups,
index_dtype: str,
mutations: Optional[Set[str]] = None,
reduction_hint=ir.ReductionHint.DEFAULT,
reduction_hint=torch._inductor.runtime.hints.ReductionHint.DEFAULT,
min_elem_per_thread=0,
):
super().__init__(

View File

@@ -5,7 +5,8 @@ import sympy
import torch
from .. import config
from ..utils import _type_of, instance_descriptor
from ..runtime.hints import instance_descriptor
from ..utils import _type_of
from ..virtualized import V
from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg

View File

@@ -8,7 +8,6 @@ import re
import textwrap
import traceback
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import (
Any,
@@ -61,6 +60,7 @@ from .dependencies import (
var_builder,
)
from .ops_handler import OpCounterCSE
from .runtime.hints import ReductionHint
from .utils import (
argsort,
cache_on_self,
@@ -533,18 +533,6 @@ class Scatter(Pointwise):
)
class ReductionHint(Enum):
INNER = 0
OUTER = 1
OUTER_TINY = 2
DEFAULT = 3
class TileHint(Enum):
SQUARE = 0
DEFAULT = 1
REDUCTION_COMBINE_FN = {
"any": ops_wrapper("logical_or"),
"max": ops_wrapper("maximum"),

View File

@@ -1,12 +0,0 @@
from . import triton_helpers, triton_heuristics
from .triton_helpers import libdevice, math as tl_math
from .triton_heuristics import AutotuneHint
__all__ = [
"triton_heuristics",
"triton_helpers",
"libdevice",
"tl_math",
"AutotuneHint",
]

View File

@@ -0,0 +1,82 @@
import collections
from dataclasses import fields
from enum import auto, Enum
class ReductionHint(Enum):
INNER = 0
OUTER = 1
OUTER_TINY = 2
DEFAULT = 3
class TileHint(Enum):
SQUARE = 0
DEFAULT = 1
# Attempt to import AttrsDescriptor from Triton
try:
from triton.compiler.compiler import AttrsDescriptor
attrs_descriptor_available = True
# Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError:
attrs_descriptor_available = False
# Define `instance_descriptor` function with clear conditional handling
if attrs_descriptor_available:
def instance_descriptor(
divisible_by_16=None,
equal_to_1=None,
ids_of_folded_args=None,
divisible_by_8=None,
):
# Prepare the arguments for AttrsDescriptor
kwargs = {
"divisible_by_16": divisible_by_16,
"equal_to_1": equal_to_1,
}
# Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
if ids_of_folded_args_available:
kwargs["ids_of_folded_args"] = ids_of_folded_args
if divisible_by_8_available:
kwargs["divisible_by_8"] = divisible_by_8
# Instantiate AttrsDescriptor with the prepared arguments
return AttrsDescriptor(**kwargs)
else:
# Define a namedtuple as a fallback when AttrsDescriptor is not available
instance_descriptor = collections.namedtuple( # type: ignore[no-redef]
"instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[tuple(), tuple(), tuple(), tuple()],
)
_NUM_THREADS_PER_WARP = 32
class HeuristicType(Enum):
PERSISTENT_REDUCTION = auto()
POINTWISE = auto()
REDUCTION = auto()
SPLIT_SCAN = auto()
TEMPLATE = auto()
USER_AUTOTUNE = auto()
class AutotuneHint(Enum):
ELEMENTS_PER_WARP_32 = 0
# Triton codegen tries to codegen a set of AutotuneHints.
# Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
# which isn't valid Python.
# Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
__repr__ = Enum.__str__
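The comment above is the design note behind the __repr__ override; a small standalone sketch (hypothetical PlainHint/CodegenHint enums, not part of the diff) of the repr behavior it relies on:

from enum import Enum

class PlainHint(Enum):  # hypothetical enum without the override
    ELEMENTS_PER_WARP_32 = 0

class CodegenHint(Enum):  # hypothetical enum mirroring the override above
    ELEMENTS_PER_WARP_32 = 0
    __repr__ = Enum.__str__

# repr() of a set uses repr() of its elements, and Triton codegen emits that
# text into generated Python source:
print(repr({PlainHint.ELEMENTS_PER_WARP_32}))
# -> {<PlainHint.ELEMENTS_PER_WARP_32: 0>}   (not valid Python syntax)
print(repr({CodegenHint.ELEMENTS_PER_WARP_32}))
# -> {CodegenHint.ELEMENTS_PER_WARP_32}      (valid, if the name is importable)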

View File

@@ -12,7 +12,6 @@ import os.path
import re
import threading
import time
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
import torch
@@ -24,8 +23,6 @@ from torch._dynamo.utils import dynamo_timed, get_first_attr
from torch._inductor import config
from torch._inductor.codecache import cache_dir, CudaKernelParamCache
from torch._inductor.coordinate_descent_tuner import CoordescTuner
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.utils import (
ceildiv,
conditional_product,
@@ -37,6 +34,13 @@ from torch._inductor.utils import (
triton_config_to_hashable,
)
from torch.utils._triton import has_triton_package
from .hints import (
_NUM_THREADS_PER_WARP,
AutotuneHint,
HeuristicType,
ReductionHint,
TileHint,
)
log = logging.getLogger(__name__)
@@ -59,28 +63,6 @@ else:
ASTSource = None
_NUM_THREADS_PER_WARP = 32
class HeuristicType(Enum):
PERSISTENT_REDUCTION = auto()
POINTWISE = auto()
REDUCTION = auto()
SPLIT_SCAN = auto()
TEMPLATE = auto()
USER_AUTOTUNE = auto()
class AutotuneHint(Enum):
ELEMENTS_PER_WARP_32 = 0
# Triton codegen tries to codegen a set of AutotuneHints.
# Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
# which isn't valid Python.
# Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
__repr__ = Enum.__str__
def autotune_hints_to_configs(
hints: Set[AutotuneHint], size_hints, block_size: int
) -> List[Config]:

View File

@@ -21,7 +21,6 @@ import tempfile
import textwrap
import time
import unittest
from dataclasses import fields
from datetime import datetime
from io import StringIO
from typing import (
@@ -689,51 +688,6 @@ def output_node(gm: torch.fx.GraphModule):
return last_node
# Attempt to import AttrsDescriptor from Triton
try:
from triton.compiler.compiler import AttrsDescriptor
attrs_descriptor_available = True
# Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError:
attrs_descriptor_available = False
# Define `instance_descriptor` function with clear conditional handling
if attrs_descriptor_available:
def instance_descriptor(
divisible_by_16=None,
equal_to_1=None,
ids_of_folded_args=None,
divisible_by_8=None,
):
# Prepare the arguments for AttrsDescriptor
kwargs = {
"divisible_by_16": divisible_by_16,
"equal_to_1": equal_to_1,
}
# Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
if ids_of_folded_args_available:
kwargs["ids_of_folded_args"] = ids_of_folded_args
if divisible_by_8_available:
kwargs["divisible_by_8"] = divisible_by_8
# Instantiate AttrsDescriptor with the prepared arguments
return AttrsDescriptor(**kwargs)
else:
# Define a namedtuple as a fallback when AttrsDescriptor is not available
instance_descriptor = collections.namedtuple( # type: ignore[no-redef]
"instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[tuple(), tuple(), tuple(), tuple()],
)
_registered_caches: List[Any] = []