[inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557)

I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124557
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552, #124553
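To make the call-site change concrete, here is a minimal sketch of how imports are updated throughout this stack. It assumes this PR is applied; note that torch._inductor.runtime.runtime_utils itself still imports torch at this point, so the faster worker startup is the goal of the overall plan, not something this commit already delivers.

# Before this stack, the shared helpers were imported from the main utils module:
#   from torch._inductor.utils import cache_dir, do_bench, next_power_of_2
# After it, the same helpers live in the new runtime package:
from torch._inductor.runtime.runtime_utils import cache_dir, next_power_of_2

print(cache_dir())            # /tmp/torchinductor_<user> unless TORCHINDUCTOR_CACHE_DIR is set
print(next_power_of_2(1000))  # 1024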
Jason Ansel
2024-04-21 11:09:45 -07:00
committed by PyTorch MergeBot
parent f4d47f5bbb
commit fcf28b0ad5
20 changed files with 196 additions and 178 deletions

View File

@@ -1,6 +1,6 @@
import torch
from torch._inductor import ir
from torch._inductor.utils import do_bench
from torch._inductor.runtime.runtime_utils import do_bench
def to_channels_last(x):

View File

@@ -13,8 +13,8 @@ from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
from torch._inductor.utils import cache_dir
from torch.export import Dim, export
from torch.testing import FileCheck

View File

@@ -19,8 +19,9 @@ from torch._inductor.codecache import (
TensorMetadata,
TensorMetadataAndValues,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import cache_dir, fresh_inductor_cache
from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (

View File

@@ -4,11 +4,11 @@ import functools
import logging
import torch
from torch._inductor.runtime.runtime_utils import do_bench
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import do_bench, do_bench_using_profiling
from torch._inductor.utils import do_bench_using_profiling
log = logging.getLogger(__name__)

View File

@@ -12,7 +12,8 @@ from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss
from torch._dynamo.utils import maybe_cprofile
from torch._inductor import config, ir, metrics
from torch._inductor.fx_passes import pad_mm as pad_mm_pass
from torch._inductor.utils import do_bench, run_and_get_code
from torch._inductor.runtime.runtime_utils import do_bench
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.inductor_utils import HAS_CUDA
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"

View File

@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from torch._inductor.select_algorithm import TritonTemplateCaller
from . import config
from .utils import do_bench
from .runtime.runtime_utils import do_bench
from .virtualized import V
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"

View File

@@ -59,7 +59,8 @@ from torch._dynamo.device_interface import (
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import cache_dir, clear_on_fresh_inductor_cache, is_linux
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import clear_on_fresh_inductor_cache, is_linux
from torch._subclasses.fake_tensor import (
extract_tensor_metadata,
FakeTensor,

View File

@@ -8,10 +8,10 @@ from typing import Any, List, Optional
import sympy
import torch
from ...codecache import cache_dir
from ...config import cuda as inductor_cuda_config
from ...ir import Layout
from ...runtime.runtime_utils import cache_dir
from .cuda_env import get_cuda_arch, get_cuda_version
log = logging.getLogger(__name__)

View File

@@ -6,7 +6,8 @@ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from .. import config
from ..codecache import PyCodeCache, TritonFuture
from ..utils import cache_on_self, do_bench
from ..runtime.runtime_utils import do_bench
from ..utils import cache_on_self
from ..virtualized import V
from .common import TensorArg

View File

@@ -47,24 +47,26 @@ from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
from ..ir import IRNode, TritonTemplateBuffer
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..runtime.hints import ReductionHint
from ..runtime.runtime_utils import (
do_bench,
get_max_y_grid,
green_text,
next_power_of_2,
yellow_text,
)
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
cache_on_self,
do_bench,
get_dtype_size,
get_fused_kernel_name,
get_kernel_metadata,
get_max_y_grid,
green_text,
is_welford_reduction,
next_power_of_2,
Placeholder,
sympy_dot,
sympy_index_symbol,
sympy_product,
sympy_subs,
unique,
yellow_text,
)
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
from ..wrapper_benchmark import get_kernel_category_by_source_code

View File

@@ -4,7 +4,7 @@ import logging
from typing import Callable, Optional
from torch.utils._triton import has_triton
from .utils import red_text, triton_config_to_hashable
from .runtime.runtime_utils import red_text, triton_config_to_hashable
if has_triton():
import triton

View File

@@ -2,6 +2,7 @@ import functools
from typing import List, Optional, Union
import torch
import torch._inductor.runtime.runtime_utils
from torch import Tensor
from torch._inductor import utils
from torch._subclasses.fake_tensor import FakeTensor
@@ -241,7 +242,7 @@ def should_pad_bench(
return False
do_bench = functools.partial(
utils.do_bench,
torch._inductor.runtime.runtime_utils.do_bench,
warmup=5,
)

View File

@@ -61,6 +61,7 @@ from .dependencies import (
)
from .ops_handler import OpCounterCSE
from .runtime.hints import ReductionHint
from .runtime.runtime_utils import do_bench
from .utils import (
argsort,
cache_on_self,
@@ -68,7 +69,6 @@ from .utils import (
convert_shape_to_inductor,
convert_shape_to_symint,
developer_warning,
do_bench,
get_kernel_metadata,
is_dynamic,
is_gpu,

View File

@@ -9,7 +9,8 @@ from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V
from .. import config as inductor_config
from ..utils import ceildiv as cdiv, next_power_of_2
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv
log = logging.getLogger(__name__)

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
import functools
import getpass
import inspect
import operator
import os
import re
import tempfile
import torch
def conditional_product(*args):
return functools.reduce(operator.mul, [x for x in args if x])
def ceildiv(numer: int, denom: int) -> int:
return -(numer // -denom)
def next_power_of_2(n: int) -> int:
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
"""
Return the total number of bytes the arguments of tensor type takes.
For in/out args, tensor sizes are counted twice: once for reading and
once for writing.
The first num_in_out_args arguments are in out tensors.
"""
return sum(
arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
for i, arg in enumerate(args)
if isinstance(arg, torch.Tensor)
)
def triton_config_to_hashable(cfg):
"""
Convert triton config to a tuple that can uniquely identify it. We can use
the return value as a dictionary key.
"""
items = sorted(cfg.kwargs.items())
items.append(("num_warps", cfg.num_warps))
items.append(("num_stages", cfg.num_stages))
return tuple(items)
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
slow = ms > 0.012 and gb_per_s < 650
return red_text(info_str) if color and slow else info_str
def get_max_y_grid():
return 65535
def do_bench(*args, **kwargs):
@functools.lru_cache(None)
def load_triton():
try:
# NB: Lazily load triton, as importing triton is slow
# see https://github.com/openai/triton/issues/1599
from triton.testing import do_bench as triton_do_bench
except ImportError as exc:
raise NotImplementedError("requires Triton") from exc
# triton PR https://github.com/openai/triton/pull/1513 change the
# quantile fields name from 'percentiles' to 'quantiles'
# and change the default value from (0.5, 0.2, 0.8) to None.
# This may break inductor since a caller expects a tuple may get a item.
#
# Add a wrapper to maintain the same behavior for inductor.
# Maybe we should have own implementation of this function?
return triton_do_bench, (
"quantiles"
if inspect.signature(triton_do_bench).parameters.get("quantiles")
is not None
else "percentiles"
)
triton_do_bench, quantile_field_name = load_triton()
if quantile_field_name not in kwargs:
kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
return triton_do_bench(*args, **kwargs)[0]
def cache_dir() -> str:
cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
if cache_dir is None:
sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = os.path.join(
tempfile.gettempdir(),
"torchinductor_" + sanitized_username,
)
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
HAS_COLORAMA = True
try:
import colorama
except ImportError:
HAS_COLORAMA = False
def _color_text(msg, color):
if not HAS_COLORAMA:
return msg
return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
def green_text(msg):
return _color_text(msg, "green")
def yellow_text(msg):
return _color_text(msg, "yellow")
def red_text(msg):
return _color_text(msg, "red")
def blue_text(msg):
return _color_text(msg, "blue")
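A short usage sketch tying the relocated helpers above together. Assumptions: a CUDA device and a Triton installation are available, and the matmul workload and tensor sizes are purely illustrative.

import torch

from torch._inductor.runtime.runtime_utils import (
    create_bandwidth_info_str,
    do_bench,
    get_num_bytes,
)

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
out = torch.empty(4096, 4096, device="cuda")

# do_bench lazily imports triton.testing.do_bench, pins the quantiles to
# (0.5, 0.2, 0.8) for backward compatibility, and returns the median time in ms.
ms = do_bench(lambda: torch.mm(a, b, out=out))

# Two inputs read plus one output written: 3 * 4096**2 * 4 bytes ~= 0.20 GB.
# (Passing num_in_out_args=1 would count the first tensor twice, for kernels
# that both read and write it.)
num_gb = get_num_bytes(a, b, out) / 1e9
gb_per_s = num_gb / (ms / 1e3)

# Prints a line like "mm: <ms> \t <GB> \t <GB/s>", colored red when it looks slow.
print(create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="mm: "))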

View File

@@ -21,9 +21,17 @@ from torch._dynamo.device_interface import DeviceGuard, get_interface_for_device
from torch._dynamo.utils import dynamo_timed, get_first_attr
from torch._inductor import config
from torch._inductor.codecache import cache_dir, CudaKernelParamCache
from torch._inductor.coordinate_descent_tuner import CoordescTuner
from torch._inductor.utils import (
from .hints import (
_NUM_THREADS_PER_WARP,
AutotuneHint,
HeuristicType,
ReductionHint,
TileHint,
)
from .runtime_utils import (
cache_dir,
ceildiv,
conditional_product,
create_bandwidth_info_str,
@@ -33,20 +41,13 @@ from torch._inductor.utils import (
next_power_of_2,
triton_config_to_hashable,
)
from torch.utils._triton import has_triton_package
from .hints import (
_NUM_THREADS_PER_WARP,
AutotuneHint,
HeuristicType,
ReductionHint,
TileHint,
)
log = logging.getLogger(__name__)
if has_triton_package():
try:
import triton
except ImportError:
triton = None
if triton is not None:
from triton import Config
from triton.runtime.autotuner import OutOfResources
from triton.runtime.jit import KernelInterface
@@ -57,12 +58,14 @@ if has_triton_package():
ASTSource = None
else:
Config = object
triton = None
KernelInterface = object
OutOfResources = object
ASTSource = None
log = logging.getLogger(__name__)
def autotune_hints_to_configs(
hints: Set[AutotuneHint], size_hints, block_size: int
) -> List[Config]:
@@ -681,6 +684,8 @@ class CachingAutotuner(KernelInterface):
"meta": launcher.config.kwargs,
}
from torch._inductor.codecache import CudaKernelParamCache
if torch.version.hip is None:
CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
else:
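The hunk above also moves the CudaKernelParamCache import from module scope into the method that uses it. A generic sketch of that deferred-import pattern follows, with a hypothetical wrapper function standing in for the real method:

def _save_cuda_kernel(key, params, cubin):
    # Hypothetical helper: importing codecache only when a kernel is actually
    # saved keeps the enclosing module cheap to import, which is the point of
    # splitting out torch._inductor.runtime.
    from torch._inductor.codecache import CudaKernelParamCache
    CudaKernelParamCache.set(key, params, cubin)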

View File

@@ -35,6 +35,7 @@ from .codegen.common import get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
from .runtime.runtime_utils import green_text, red_text
from .sizevars import SimplifyIndexing
from .utils import (
cache_on_self,
@@ -44,11 +45,9 @@ from .utils import (
get_device_tflops,
get_dtype_size,
get_gpu_dram_gbps,
green_text,
is_collective,
is_gpu,
is_wait,
red_text,
sympy_product,
)
from .virtualized import V

View File

@@ -35,14 +35,8 @@ from .codegen.triton import (
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .utils import (
do_bench,
get_dtype_size,
Placeholder,
sympy_dot,
sympy_product,
unique,
)
from .runtime.runtime_utils import do_bench
from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
from .virtualized import V
log = logging.getLogger(__name__)

View File

@@ -5,7 +5,6 @@ import contextlib
import dataclasses
import enum
import functools
import getpass
import inspect
import io
import itertools
@@ -14,7 +13,6 @@ import math
import operator
import os
import platform
import re
import shutil
import sys
import tempfile
@@ -51,6 +49,7 @@ from torch.autograd.profiler_util import EventList
from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
from . import config
from .runtime.runtime_utils import ceildiv as runtime_ceildiv
log = logging.getLogger(__name__)
@@ -140,37 +139,6 @@ def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float
return res
def do_bench(*args, **kwargs):
@functools.lru_cache(None)
def load_triton():
try:
# NB: Lazily load triton, as importing triton is slow
# see https://github.com/openai/triton/issues/1599
from triton.testing import do_bench as triton_do_bench
except ImportError as exc:
raise NotImplementedError("requires Triton") from exc
# triton PR https://github.com/openai/triton/pull/1513 change the
# quantile fields name from 'percentiles' to 'quantiles'
# and change the default value from (0.5, 0.2, 0.8) to None.
# This may break inductor since a caller expects a tuple may get a item.
#
# Add a wrapper to maintain the same behavior for inductor.
# Maybe we should have own implementation of this function?
return triton_do_bench, (
"quantiles"
if inspect.signature(triton_do_bench).parameters.get("quantiles")
is not None
else "percentiles"
)
triton_do_bench, quantile_field_name = load_triton()
if quantile_field_name not in kwargs:
kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
return triton_do_bench(*args, **kwargs)[0]
@functools.lru_cache(None)
def has_torchvision_roi_align() -> bool:
try:
@@ -183,10 +151,6 @@ def has_torchvision_roi_align() -> bool:
return False
def conditional_product(*args):
return functools.reduce(operator.mul, [x for x in args if x])
def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
if device is None:
return torch.tensor(0.0).device # default device
@@ -222,20 +186,7 @@ def ceildiv(
assert isinstance(numer, int) and isinstance(
denom, int
), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
return -(numer // -denom)
def next_power_of_2(n: int) -> int:
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n
return runtime_ceildiv(numer, denom)
def _type_of(key):
@@ -703,20 +654,6 @@ def clear_on_fresh_inductor_cache(obj: Any):
return obj
@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def cache_dir() -> str:
cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
if cache_dir is None:
sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
cache_dir = os.path.join(
tempfile.gettempdir(),
"torchinductor_" + sanitized_username,
)
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
@contextlib.contextmanager
def fresh_inductor_cache(cache_entries=None):
"""
@@ -1141,28 +1078,6 @@ def developer_warning(msg):
log.info(msg)
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
"""
Return the total number of bytes the arguments of tensor type takes.
For in/out args, tensor sizes are counted twice: once for reading and
once for writing.
The first num_in_out_args arguments are in out tensors.
"""
return sum(
arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
for i, arg in enumerate(args)
if isinstance(arg, torch.Tensor)
)
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
slow = ms > 0.012 and gb_per_s < 650
return red_text(info_str) if color and slow else info_str
def get_benchmark_name():
"""
An experimental API used only when config.benchmark_kernel is true.
@@ -1229,17 +1144,6 @@ def maybe_profile(should_profile, *args, **kwargs):
yield
def triton_config_to_hashable(cfg):
"""
Convert triton config to a tuple that can uniquely identify it. We can use
the return value as a dictionary key.
"""
items = sorted(cfg.kwargs.items())
items.append(("num_warps", cfg.num_warps))
items.append(("num_stages", cfg.num_stages))
return tuple(items)
def parallel_num_threads():
threads = config.cpp.threads
if threads < 1:
@@ -1247,36 +1151,6 @@ def parallel_num_threads():
return threads
HAS_COLORAMA = True
try:
import colorama
except ImportError:
HAS_COLORAMA = False
def _color_text(msg, color):
if not HAS_COLORAMA:
return msg
return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
def green_text(msg):
return _color_text(msg, "green")
def yellow_text(msg):
return _color_text(msg, "yellow")
def red_text(msg):
return _color_text(msg, "red")
def blue_text(msg):
return _color_text(msg, "blue")
@functools.lru_cache(None)
def get_device_tflops(dtype):
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
@@ -1320,10 +1194,6 @@ def reduction_num_outputs(reduction_type):
return 3 if is_welford_reduction(reduction_type) else 1
def get_max_y_grid():
return 65535
def is_linux() -> bool:
return platform.system() == "Linux"
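Since torch._inductor.utils.ceildiv now just forwards to the runtime implementation, the two entry points should agree. A quick illustrative check (values are arbitrary):

from torch._inductor.runtime.runtime_utils import ceildiv as runtime_ceildiv
from torch._inductor.utils import ceildiv

# -(10 // -3) == -(-4) == 4: ceiling division using only integer floor division.
assert ceildiv(10, 3) == runtime_ceildiv(10, 3) == 4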

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
import torch
from torch.autograd import DeviceType
from .utils import create_bandwidth_info_str, do_bench, get_num_bytes
from .runtime.runtime_utils import create_bandwidth_info_str, do_bench, get_num_bytes
_kernel_category_choices = [
"foreach",