Mirror of https://github.com/pytorch/pytorch.git
[inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557)
I am planning to make the compile_worker process not import torch so it can start up much faster. This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124557
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552, #124553
Committed by PyTorch MergeBot
parent f4d47f5bbb
commit fcf28b0ad5
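The change below is mechanical: helpers such as do_bench, cache_dir, ceildiv, next_power_of_2, get_num_bytes, and the colorama text helpers move from torch._inductor.utils (and torch._inductor.codecache) into the new torch._inductor.runtime.runtime_utils module, and every caller updates its import. A minimal sketch of the new call sites (the matrix sizes and the torch.mm workload are illustrative, not part of the patch; do_bench needs Triton and a CUDA device to actually run):

import torch

# Old location (before this PR): from torch._inductor.utils import do_bench
# New location (after this PR):
from torch._inductor.runtime.runtime_utils import do_bench, next_power_of_2

print(next_power_of_2(1000))  # 1024

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    # do_bench wraps triton.testing.do_bench and returns the median time in ms.
    ms = do_bench(lambda: torch.mm(a, b))
    print(f"torch.mm 1024x1024: {ms:.3f} ms")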
@@ -1,6 +1,6 @@
import torch
from torch._inductor import ir
from torch._inductor.utils import do_bench
from torch._inductor.runtime.runtime_utils import do_bench


def to_channels_last(x):
@@ -13,8 +13,8 @@ from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase
from torch._inductor.utils import cache_dir

from torch.export import Dim, export
from torch.testing import FileCheck
@@ -19,8 +19,9 @@ from torch._inductor.codecache import (
    TensorMetadata,
    TensorMetadataAndValues,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import cache_dir, fresh_inductor_cache
from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
@@ -4,11 +4,11 @@ import functools
import logging

import torch
from torch._inductor.runtime.runtime_utils import do_bench

from torch._inductor.test_case import run_tests, TestCase

from torch._inductor.utils import do_bench, do_bench_using_profiling

from torch._inductor.utils import do_bench_using_profiling

log = logging.getLogger(__name__)
@@ -12,7 +12,8 @@ from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss
from torch._dynamo.utils import maybe_cprofile
from torch._inductor import config, ir, metrics
from torch._inductor.fx_passes import pad_mm as pad_mm_pass
from torch._inductor.utils import do_bench, run_and_get_code
from torch._inductor.runtime.runtime_utils import do_bench
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.inductor_utils import HAS_CUDA

DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
    from torch._inductor.select_algorithm import TritonTemplateCaller

from . import config
from .utils import do_bench
from .runtime.runtime_utils import do_bench
from .virtualized import V

CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -59,7 +59,8 @@ from torch._dynamo.device_interface import (
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import cache_dir, clear_on_fresh_inductor_cache, is_linux
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import clear_on_fresh_inductor_cache, is_linux
from torch._subclasses.fake_tensor import (
    extract_tensor_metadata,
    FakeTensor,
@@ -8,10 +8,10 @@ from typing import Any, List, Optional
import sympy

import torch

from ...codecache import cache_dir
from ...config import cuda as inductor_cuda_config
from ...ir import Layout

from ...runtime.runtime_utils import cache_dir
from .cuda_env import get_cuda_arch, get_cuda_version

log = logging.getLogger(__name__)
@@ -6,7 +6,8 @@ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled

from .. import config
from ..codecache import PyCodeCache, TritonFuture
from ..utils import cache_on_self, do_bench
from ..runtime.runtime_utils import do_bench
from ..utils import cache_on_self
from ..virtualized import V
from .common import TensorArg
@@ -47,24 +47,26 @@ from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
from ..ir import IRNode, TritonTemplateBuffer
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..runtime.hints import ReductionHint
from ..runtime.runtime_utils import (
    do_bench,
    get_max_y_grid,
    green_text,
    next_power_of_2,
    yellow_text,
)
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
    cache_on_self,
    do_bench,
    get_dtype_size,
    get_fused_kernel_name,
    get_kernel_metadata,
    get_max_y_grid,
    green_text,
    is_welford_reduction,
    next_power_of_2,
    Placeholder,
    sympy_dot,
    sympy_index_symbol,
    sympy_product,
    sympy_subs,
    unique,
    yellow_text,
)
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
from ..wrapper_benchmark import get_kernel_category_by_source_code
@@ -4,7 +4,7 @@ import logging
from typing import Callable, Optional

from torch.utils._triton import has_triton
from .utils import red_text, triton_config_to_hashable
from .runtime.runtime_utils import red_text, triton_config_to_hashable

if has_triton():
    import triton
@@ -2,6 +2,7 @@ import functools
from typing import List, Optional, Union

import torch
import torch._inductor.runtime.runtime_utils
from torch import Tensor
from torch._inductor import utils
from torch._subclasses.fake_tensor import FakeTensor
@@ -241,7 +242,7 @@ def should_pad_bench(
        return False

    do_bench = functools.partial(
        utils.do_bench,
        torch._inductor.runtime.runtime_utils.do_bench,
        warmup=5,
    )
@@ -61,6 +61,7 @@ from .dependencies import (
)
from .ops_handler import OpCounterCSE
from .runtime.hints import ReductionHint
from .runtime.runtime_utils import do_bench
from .utils import (
    argsort,
    cache_on_self,
@@ -68,7 +69,6 @@ from .utils import (
    convert_shape_to_inductor,
    convert_shape_to_symint,
    developer_warning,
    do_bench,
    get_kernel_metadata,
    is_dynamic,
    is_gpu,
@@ -9,7 +9,8 @@ from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..utils import ceildiv as cdiv, next_power_of_2
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv

log = logging.getLogger(__name__)
torch/_inductor/runtime/runtime_utils.py (new file, 142 lines)
@@ -0,0 +1,142 @@
from __future__ import annotations

import functools
import getpass
import inspect
import operator
import os
import re
import tempfile

import torch


def conditional_product(*args):
    return functools.reduce(operator.mul, [x for x in args if x])


def ceildiv(numer: int, denom: int) -> int:
    return -(numer // -denom)


def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n


def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes the arguments of tensor type takes.

    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.

    The first num_in_out_args arguments are in out tensors.
    """
    return sum(
        arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
        for i, arg in enumerate(args)
        if isinstance(arg, torch.Tensor)
    )


def triton_config_to_hashable(cfg):
    """
    Convert triton config to a tuple that can uniquely identify it. We can use
    the return value as a dictionary key.
    """
    items = sorted(cfg.kwargs.items())
    items.append(("num_warps", cfg.num_warps))
    items.append(("num_stages", cfg.num_stages))
    return tuple(items)


def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    slow = ms > 0.012 and gb_per_s < 650
    return red_text(info_str) if color and slow else info_str


def get_max_y_grid():
    return 65535


def do_bench(*args, **kwargs):
    @functools.lru_cache(None)
    def load_triton():
        try:
            # NB: Lazily load triton, as importing triton is slow
            # see https://github.com/openai/triton/issues/1599
            from triton.testing import do_bench as triton_do_bench
        except ImportError as exc:
            raise NotImplementedError("requires Triton") from exc

        # triton PR https://github.com/openai/triton/pull/1513 change the
        # quantile fields name from 'percentiles' to 'quantiles'
        # and change the default value from (0.5, 0.2, 0.8) to None.
        # This may break inductor since a caller expects a tuple may get a item.
        #
        # Add a wrapper to maintain the same behavior for inductor.
        # Maybe we should have own implementation of this function?
        return triton_do_bench, (
            "quantiles"
            if inspect.signature(triton_do_bench).parameters.get("quantiles")
            is not None
            else "percentiles"
        )

    triton_do_bench, quantile_field_name = load_triton()

    if quantile_field_name not in kwargs:
        kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
    return triton_do_bench(*args, **kwargs)[0]


def cache_dir() -> str:
    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if cache_dir is None:
        sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = os.path.join(
            tempfile.gettempdir(),
            "torchinductor_" + sanitized_username,
        )
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


HAS_COLORAMA = True
try:
    import colorama
except ImportError:
    HAS_COLORAMA = False


def _color_text(msg, color):
    if not HAS_COLORAMA:
        return msg

    return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET


def green_text(msg):
    return _color_text(msg, "green")


def yellow_text(msg):
    return _color_text(msg, "yellow")


def red_text(msg):
    return _color_text(msg, "red")


def blue_text(msg):
    return _color_text(msg, "blue")
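Note that the new module above only imports the standard library and torch eagerly; Triton is loaded lazily inside do_bench, which is what keeps this file cheap to import for a future compile worker. A small, hedged usage example of the pure-Python helpers (the cache path shown is illustrative and is only used if TORCHINDUCTOR_CACHE_DIR is unset in your environment):

import os

from torch._inductor.runtime.runtime_utils import cache_dir, ceildiv, conditional_product

print(ceildiv(7, 2))                    # 4
print(conditional_product(3, None, 4))  # 12 -- falsy arguments are skipped

# cache_dir() honors TORCHINDUCTOR_CACHE_DIR when set; otherwise it creates
# <tempdir>/torchinductor_<sanitized username> and exports the variable.
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor_demo")
print(cache_dir())                      # /tmp/torchinductor_demo (created if missing)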
@@ -21,9 +21,17 @@ from torch._dynamo.device_interface import DeviceGuard, get_interface_for_device
from torch._dynamo.utils import dynamo_timed, get_first_attr

from torch._inductor import config
from torch._inductor.codecache import cache_dir, CudaKernelParamCache
from torch._inductor.coordinate_descent_tuner import CoordescTuner
from torch._inductor.utils import (
from .hints import (
    _NUM_THREADS_PER_WARP,
    AutotuneHint,
    HeuristicType,
    ReductionHint,
    TileHint,
)

from .runtime_utils import (
    cache_dir,
    ceildiv,
    conditional_product,
    create_bandwidth_info_str,
@@ -33,20 +41,13 @@ from torch._inductor.utils import (
    next_power_of_2,
    triton_config_to_hashable,
)
from torch.utils._triton import has_triton_package
from .hints import (
    _NUM_THREADS_PER_WARP,
    AutotuneHint,
    HeuristicType,
    ReductionHint,
    TileHint,
)


log = logging.getLogger(__name__)

if has_triton_package():
try:
    import triton
except ImportError:
    triton = None

if triton is not None:
    from triton import Config
    from triton.runtime.autotuner import OutOfResources
    from triton.runtime.jit import KernelInterface
@@ -57,12 +58,14 @@ if has_triton_package():
        ASTSource = None
else:
    Config = object
    triton = None
    KernelInterface = object
    OutOfResources = object
    ASTSource = None


log = logging.getLogger(__name__)


def autotune_hints_to_configs(
    hints: Set[AutotuneHint], size_hints, block_size: int
) -> List[Config]:
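The triton_heuristics hunks above also appear to swap the torch.utils._triton.has_triton_package() probe for a plain guarded import, so the runtime module can decide whether Triton is present without pulling in torch utility code. A minimal sketch of that pattern, assuming nothing beyond what the diff shows (triton_available is an illustrative helper, not part of the patch):

# Probe for Triton directly instead of importing torch.utils._triton, and fall
# back to placeholder types so annotations still resolve when Triton is absent.
try:
    import triton
    from triton import Config
except ImportError:
    triton = None
    Config = object


def triton_available() -> bool:
    # Illustrative helper: callers branch on this instead of has_triton_package().
    return triton is not None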
@@ -681,6 +684,8 @@ class CachingAutotuner(KernelInterface):
            "meta": launcher.config.kwargs,
        }

        from torch._inductor.codecache import CudaKernelParamCache

        if torch.version.hip is None:
            CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
        else:
@@ -35,6 +35,7 @@ from .codegen.common import get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
from .runtime.runtime_utils import green_text, red_text
from .sizevars import SimplifyIndexing
from .utils import (
    cache_on_self,
@@ -44,11 +45,9 @@ from .utils import (
    get_device_tflops,
    get_dtype_size,
    get_gpu_dram_gbps,
    green_text,
    is_collective,
    is_gpu,
    is_wait,
    red_text,
    sympy_product,
)
from .virtualized import V
@@ -35,14 +35,8 @@ from .codegen.triton import (
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .utils import (
    do_bench,
    get_dtype_size,
    Placeholder,
    sympy_dot,
    sympy_product,
    unique,
)
from .runtime.runtime_utils import do_bench
from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
from .virtualized import V

log = logging.getLogger(__name__)
@@ -5,7 +5,6 @@ import contextlib
import dataclasses
import enum
import functools
import getpass
import inspect
import io
import itertools
@@ -14,7 +13,6 @@ import math
import operator
import os
import platform
import re
import shutil
import sys
import tempfile
@@ -51,6 +49,7 @@ from torch.autograd.profiler_util import EventList
from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
from . import config
from .runtime.runtime_utils import ceildiv as runtime_ceildiv

log = logging.getLogger(__name__)

@@ -140,37 +139,6 @@ def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float
    return res


def do_bench(*args, **kwargs):
    @functools.lru_cache(None)
    def load_triton():
        try:
            # NB: Lazily load triton, as importing triton is slow
            # see https://github.com/openai/triton/issues/1599
            from triton.testing import do_bench as triton_do_bench
        except ImportError as exc:
            raise NotImplementedError("requires Triton") from exc

        # triton PR https://github.com/openai/triton/pull/1513 change the
        # quantile fields name from 'percentiles' to 'quantiles'
        # and change the default value from (0.5, 0.2, 0.8) to None.
        # This may break inductor since a caller expects a tuple may get a item.
        #
        # Add a wrapper to maintain the same behavior for inductor.
        # Maybe we should have own implementation of this function?
        return triton_do_bench, (
            "quantiles"
            if inspect.signature(triton_do_bench).parameters.get("quantiles")
            is not None
            else "percentiles"
        )

    triton_do_bench, quantile_field_name = load_triton()

    if quantile_field_name not in kwargs:
        kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
    return triton_do_bench(*args, **kwargs)[0]


@functools.lru_cache(None)
def has_torchvision_roi_align() -> bool:
    try:
@@ -183,10 +151,6 @@ def has_torchvision_roi_align() -> bool:
        return False


def conditional_product(*args):
    return functools.reduce(operator.mul, [x for x in args if x])


def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
    if device is None:
        return torch.tensor(0.0).device  # default device
@@ -222,20 +186,7 @@ def ceildiv(
    assert isinstance(numer, int) and isinstance(
        denom, int
    ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
    return -(numer // -denom)


def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n
    return runtime_ceildiv(numer, denom)


def _type_of(key):
@@ -703,20 +654,6 @@ def clear_on_fresh_inductor_cache(obj: Any):
    return obj


@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def cache_dir() -> str:
    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if cache_dir is None:
        sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
        cache_dir = os.path.join(
            tempfile.gettempdir(),
            "torchinductor_" + sanitized_username,
        )
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


@contextlib.contextmanager
def fresh_inductor_cache(cache_entries=None):
    """
@@ -1141,28 +1078,6 @@ def developer_warning(msg):
        log.info(msg)


def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes the arguments of tensor type takes.

    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.

    The first num_in_out_args arguments are in out tensors.
    """
    return sum(
        arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
        for i, arg in enumerate(args)
        if isinstance(arg, torch.Tensor)
    )


def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    slow = ms > 0.012 and gb_per_s < 650
    return red_text(info_str) if color and slow else info_str


def get_benchmark_name():
    """
    An experimental API used only when config.benchmark_kernel is true.
@@ -1229,17 +1144,6 @@ def maybe_profile(should_profile, *args, **kwargs):
        yield


def triton_config_to_hashable(cfg):
    """
    Convert triton config to a tuple that can uniquely identify it. We can use
    the return value as a dictionary key.
    """
    items = sorted(cfg.kwargs.items())
    items.append(("num_warps", cfg.num_warps))
    items.append(("num_stages", cfg.num_stages))
    return tuple(items)


def parallel_num_threads():
    threads = config.cpp.threads
    if threads < 1:
@@ -1247,36 +1151,6 @@ def parallel_num_threads():
    return threads


HAS_COLORAMA = True
try:
    import colorama
except ImportError:
    HAS_COLORAMA = False


def _color_text(msg, color):
    if not HAS_COLORAMA:
        return msg

    return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET


def green_text(msg):
    return _color_text(msg, "green")


def yellow_text(msg):
    return _color_text(msg, "yellow")


def red_text(msg):
    return _color_text(msg, "red")


def blue_text(msg):
    return _color_text(msg, "blue")


@functools.lru_cache(None)
def get_device_tflops(dtype):
    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
@@ -1320,10 +1194,6 @@ def reduction_num_outputs(reduction_type):
    return 3 if is_welford_reduction(reduction_type) else 1


def get_max_y_grid():
    return 65535


def is_linux() -> bool:
    return platform.system() == "Linux"
@@ -4,7 +4,7 @@ from collections import defaultdict

import torch
from torch.autograd import DeviceType
from .utils import create_bandwidth_info_str, do_bench, get_num_bytes
from .runtime.runtime_utils import create_bandwidth_info_str, do_bench, get_num_bytes

_kernel_category_choices = [
    "foreach",