Mirror of https://github.com/pytorch/pytorch.git
[inductor] Refactor runtime files into torch._inductor.runtime (part 2) (#124553)
I am planning to make the compile_worker process not import torch so it can start up much faster. This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124553
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552
Commit: bb8815bc31
Parent: 480585fd2b
Committed by: PyTorch MergeBot
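In short, call sites stop importing the hint types (ReductionHint, TileHint, HeuristicType, AutotuneHint, instance_descriptor) from torch._inductor.ir and torch._inductor.utils and import them from the new torch._inductor.runtime.hints module instead. A minimal sketch of the before/after import pattern, using only names that appear in the hunks below:

# Before this change (see the removed lines in the hunks below):
# from torch._inductor.ir import ReductionHint, TileHint
# from torch._inductor.utils import instance_descriptor

# After this change:
from torch._inductor.runtime.hints import ReductionHint, TileHint, instance_descriptor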
@@ -381,12 +381,8 @@ class CudaReproTests(TestCase):
         https://github.com/pytorch/torchdynamo/issues/1670
         """
         from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
-        from torch._inductor.runtime.triton_heuristics import (
-            CachingAutotuner,
-            grid,
-            HeuristicType,
-        )
-        from torch._inductor.utils import instance_descriptor
+        from torch._inductor.runtime.hints import HeuristicType, instance_descriptor
+        from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid

         def autotune(configs, meta):
             def decorator(fn):
@@ -34,7 +34,7 @@ import torch.utils._pytree as pytree
 from torch._dynamo.utils import preserve_rng_state

 from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
-from torch._inductor.runtime.triton_heuristics import AutotuneHint
+from torch._inductor.runtime.hints import AutotuneHint
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import ValueRanges
@@ -44,8 +44,9 @@ from ..._dynamo.utils import counters
 from .. import config, ir, scheduler
 from ..codecache import code_hash, get_path, PyCodeCache
 from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
-from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
+from ..ir import IRNode, TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
+from ..runtime.hints import ReductionHint
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
 from ..utils import (
     cache_on_self,
@@ -120,15 +121,9 @@ def gen_common_triton_imports():

     imports.splice(
         """
-        from torch._inductor.runtime import (
-            triton_helpers,
-            triton_heuristics,
-            libdevice,
-            tl_math,
-            AutotuneHint,
-        )
-        from torch._inductor.ir import ReductionHint, TileHint
-        from torch._inductor.utils import instance_descriptor
+        from torch._inductor.runtime import triton_helpers, triton_heuristics
+        from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
         """
     )
     return imports.getvalue()
@@ -2,7 +2,8 @@ import functools

 from typing import Optional, Set

-from torch._inductor import config, ir
+import torch._inductor.runtime.hints
+from torch._inductor import config

 from torch._inductor.codegen.triton import (
     IterationRangesRoot,
@@ -36,7 +37,7 @@ class TritonSplitScanKernel(TritonKernel):
         *groups,
         index_dtype: str,
         mutations: Optional[Set[str]] = None,
-        reduction_hint=ir.ReductionHint.DEFAULT,
+        reduction_hint=torch._inductor.runtime.hints.ReductionHint.DEFAULT,
         min_elem_per_thread=0,
     ):
         super().__init__(
@@ -5,7 +5,8 @@ import sympy
 import torch

 from .. import config
-from ..utils import _type_of, instance_descriptor
+from ..runtime.hints import instance_descriptor
+from ..utils import _type_of
 from ..virtualized import V
 from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg
@@ -8,7 +8,6 @@ import re
 import textwrap
 import traceback
 from contextlib import nullcontext
-from enum import Enum
 from functools import partial
 from typing import (
     Any,
@@ -61,6 +60,7 @@ from .dependencies import (
     var_builder,
 )
 from .ops_handler import OpCounterCSE
+from .runtime.hints import ReductionHint
 from .utils import (
     argsort,
     cache_on_self,
@@ -533,18 +533,6 @@ class Scatter(Pointwise):
     )


-class ReductionHint(Enum):
-    INNER = 0
-    OUTER = 1
-    OUTER_TINY = 2
-    DEFAULT = 3
-
-
-class TileHint(Enum):
-    SQUARE = 0
-    DEFAULT = 1
-
-
 REDUCTION_COMBINE_FN = {
     "any": ops_wrapper("logical_or"),
     "max": ops_wrapper("maximum"),
@@ -1,12 +0,0 @@
-from . import triton_helpers, triton_heuristics
-from .triton_helpers import libdevice, math as tl_math
-from .triton_heuristics import AutotuneHint
-
-
-__all__ = [
-    "triton_heuristics",
-    "triton_helpers",
-    "libdevice",
-    "tl_math",
-    "AutotuneHint",
-]
torch/_inductor/runtime/hints.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+import collections
+from dataclasses import fields
+from enum import auto, Enum
+
+
+class ReductionHint(Enum):
+    INNER = 0
+    OUTER = 1
+    OUTER_TINY = 2
+    DEFAULT = 3
+
+
+class TileHint(Enum):
+    SQUARE = 0
+    DEFAULT = 1
+
+
+# Attempt to import AttrsDescriptor from Triton
+try:
+    from triton.compiler.compiler import AttrsDescriptor
+
+    attrs_descriptor_available = True
+    # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
+    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
+    ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
+    divisible_by_8_available = "divisible_by_8" in attr_desc_fields
+except ImportError:
+    attrs_descriptor_available = False
+
+# Define `instance_descriptor` function with clear conditional handling
+if attrs_descriptor_available:
+
+    def instance_descriptor(
+        divisible_by_16=None,
+        equal_to_1=None,
+        ids_of_folded_args=None,
+        divisible_by_8=None,
+    ):
+        # Prepare the arguments for AttrsDescriptor
+        kwargs = {
+            "divisible_by_16": divisible_by_16,
+            "equal_to_1": equal_to_1,
+        }
+
+        # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
+        if ids_of_folded_args_available:
+            kwargs["ids_of_folded_args"] = ids_of_folded_args
+        if divisible_by_8_available:
+            kwargs["divisible_by_8"] = divisible_by_8
+
+        # Instantiate AttrsDescriptor with the prepared arguments
+        return AttrsDescriptor(**kwargs)
+
+else:
+    # Define a namedtuple as a fallback when AttrsDescriptor is not available
+    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
+        "instance_descriptor",
+        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
+        defaults=[tuple(), tuple(), tuple(), tuple()],
+    )
+
+
+_NUM_THREADS_PER_WARP = 32
+
+
+class HeuristicType(Enum):
+    PERSISTENT_REDUCTION = auto()
+    POINTWISE = auto()
+    REDUCTION = auto()
+    SPLIT_SCAN = auto()
+    TEMPLATE = auto()
+    USER_AUTOTUNE = auto()
+
+
+class AutotuneHint(Enum):
+    ELEMENTS_PER_WARP_32 = 0
+
+    # Triton codegen tries to codegen set of AutotuneHints.
+    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>""
+    # which isn't valid python.
+    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
+    __repr__ = Enum.__str__
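For orientation, a minimal usage sketch (not part of the PR) of the helper defined in this new file: instance_descriptor wraps Triton's AttrsDescriptor when that import succeeds, and otherwise degrades to a plain namedtuple with the same field names, so callers construct it the same way either way.

# Hypothetical snippet, assuming a torch build that includes this change.
from torch._inductor.runtime.hints import ReductionHint, instance_descriptor

desc = instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())
print(desc)                   # AttrsDescriptor(...) if triton is importable, namedtuple otherwise
print(ReductionHint.DEFAULT)  # the hint enums no longer live in torch._inductor.ir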
@@ -12,7 +12,6 @@ import os.path
 import re
 import threading
 import time
-from enum import auto, Enum
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple

 import torch
@@ -24,8 +23,6 @@ from torch._dynamo.utils import dynamo_timed, get_first_attr
 from torch._inductor import config
 from torch._inductor.codecache import cache_dir, CudaKernelParamCache
 from torch._inductor.coordinate_descent_tuner import CoordescTuner
-
-from torch._inductor.ir import ReductionHint, TileHint
 from torch._inductor.utils import (
     ceildiv,
     conditional_product,
@@ -37,6 +34,13 @@ from torch._inductor.utils import (
     triton_config_to_hashable,
 )
 from torch.utils._triton import has_triton_package
+from .hints import (
+    _NUM_THREADS_PER_WARP,
+    AutotuneHint,
+    HeuristicType,
+    ReductionHint,
+    TileHint,
+)


 log = logging.getLogger(__name__)
@@ -59,28 +63,6 @@ else:
     ASTSource = None


-_NUM_THREADS_PER_WARP = 32
-
-
-class HeuristicType(Enum):
-    PERSISTENT_REDUCTION = auto()
-    POINTWISE = auto()
-    REDUCTION = auto()
-    SPLIT_SCAN = auto()
-    TEMPLATE = auto()
-    USER_AUTOTUNE = auto()
-
-
-class AutotuneHint(Enum):
-    ELEMENTS_PER_WARP_32 = 0
-
-    # Triton codegen tries to codegen set of AutotuneHints.
-    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>""
-    # which isn't valid python.
-    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
-    __repr__ = Enum.__str__
-
-
 def autotune_hints_to_configs(
     hints: Set[AutotuneHint], size_hints, block_size: int
 ) -> List[Config]:
@@ -21,7 +21,6 @@ import tempfile
 import textwrap
 import time
 import unittest
-from dataclasses import fields
 from datetime import datetime
 from io import StringIO
 from typing import (
@@ -689,51 +688,6 @@ def output_node(gm: torch.fx.GraphModule):
     return last_node


-# Attempt to import AttrsDescriptor from Triton
-try:
-    from triton.compiler.compiler import AttrsDescriptor
-
-    attrs_descriptor_available = True
-    # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
-    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
-    ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
-    divisible_by_8_available = "divisible_by_8" in attr_desc_fields
-except ImportError:
-    attrs_descriptor_available = False
-
-# Define `instance_descriptor` function with clear conditional handling
-if attrs_descriptor_available:
-
-    def instance_descriptor(
-        divisible_by_16=None,
-        equal_to_1=None,
-        ids_of_folded_args=None,
-        divisible_by_8=None,
-    ):
-        # Prepare the arguments for AttrsDescriptor
-        kwargs = {
-            "divisible_by_16": divisible_by_16,
-            "equal_to_1": equal_to_1,
-        }
-
-        # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor
-        if ids_of_folded_args_available:
-            kwargs["ids_of_folded_args"] = ids_of_folded_args
-        if divisible_by_8_available:
-            kwargs["divisible_by_8"] = divisible_by_8
-
-        # Instantiate AttrsDescriptor with the prepared arguments
-        return AttrsDescriptor(**kwargs)
-
-else:
-    # Define a namedtuple as a fallback when AttrsDescriptor is not available
-    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
-        "instance_descriptor",
-        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
-        defaults=[tuple(), tuple(), tuple(), tuple()],
-    )
-
-
 _registered_caches: List[Any] = []