pytorch/torch/_inductor/runtime/runtime_utils.py
commit fcf28b0ad5 (Jason Ansel): [inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557)
I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124557
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552, #124553
2024-04-22 04:51:15 +00:00


from __future__ import annotations
import functools
import getpass
import inspect
import operator
import os
import re
import tempfile
import torch
def conditional_product(*args):
    """Multiply together all truthy arguments, skipping falsy ones."""
    return functools.reduce(operator.mul, [x for x in args if x])
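# Usage sketch (illustrative values): falsy arguments such as 0 or None are
# skipped, so optional dimensions can be passed directly.
#   conditional_product(4, None, 8)  -> 32
#   conditional_product(2, 0, 3)     -> 6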
def ceildiv(numer: int, denom: int) -> int:
    """Ceiling division using only floor division."""
    return -(numer // -denom)
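# Usage sketch (illustrative values): rounds the quotient up instead of down.
#   ceildiv(7, 2)  -> 4
#   ceildiv(8, 2)  -> 4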
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n
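# Usage sketch (illustrative values): the shifts smear the highest set bit of
# n - 1 downward so the value becomes 2**k - 1 before the final increment.
#   next_power_of_2(17)  -> 32
#   next_power_of_2(32)  -> 32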
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes that the arguments of tensor type take.
    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.
    The first num_in_out_args arguments are in/out tensors.
    """
    return sum(
        arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
        for i, arg in enumerate(args)
        if isinstance(arg, torch.Tensor)
    )
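# Usage sketch (hypothetical tensors; byte counts follow from dtype sizes):
#   a = torch.empty(1024, dtype=torch.float32)   # 4096 bytes
#   b = torch.empty(1024, dtype=torch.float16)   # 2048 bytes
#   get_num_bytes(a, b, num_in_out_args=1)       # -> 2 * 4096 + 2048 = 10240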
def triton_config_to_hashable(cfg):
    """
    Convert a triton config to a tuple that can uniquely identify it. We can
    use the return value as a dictionary key.
    """
    items = sorted(cfg.kwargs.items())
    items.append(("num_warps", cfg.num_warps))
    items.append(("num_stages", cfg.num_stages))
    return tuple(items)
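# Usage sketch (assuming triton is installed; values are illustrative):
#   cfg = triton.Config({"XBLOCK": 128}, num_warps=4, num_stages=2)
#   timings = {triton_config_to_hashable(cfg): 0.021}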
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    # Highlight (in red) measurements that look slow: long runtime combined
    # with low achieved bandwidth.
    slow = ms > 0.012 and gb_per_s < 650
    return red_text(info_str) if color and slow else info_str
def get_max_y_grid():
    # Maximum grid size along the y dimension on CUDA devices.
    return 65535
def do_bench(*args, **kwargs):
    @functools.lru_cache(None)
    def load_triton():
        try:
            # NB: Lazily load triton, as importing triton is slow
            # see https://github.com/openai/triton/issues/1599
            from triton.testing import do_bench as triton_do_bench
        except ImportError as exc:
            raise NotImplementedError("requires Triton") from exc
        # triton PR https://github.com/openai/triton/pull/1513 changed the
        # quantile field name from 'percentiles' to 'quantiles'
        # and changed the default value from (0.5, 0.2, 0.8) to None.
        # This may break inductor since a caller expecting a tuple may get a
        # single item instead.
        #
        # Add a wrapper to maintain the same behavior for inductor.
        # Maybe we should have our own implementation of this function?
        return triton_do_bench, (
            "quantiles"
            if inspect.signature(triton_do_bench).parameters.get("quantiles")
            is not None
            else "percentiles"
        )

    triton_do_bench, quantile_field_name = load_triton()
    if quantile_field_name not in kwargs:
        kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
    return triton_do_bench(*args, **kwargs)[0]
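# Usage sketch (assuming triton and a CUDA device are available):
#   x = torch.randn(4096, 4096, device="cuda")
#   median_ms = do_bench(lambda: x @ x)  # 0.5 quantile of the measured times, in ms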
def cache_dir() -> str:
    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if cache_dir is None:
        sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = os.path.join(
            tempfile.gettempdir(),
            "torchinductor_" + sanitized_username,
        )
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
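# Usage sketch: the cache location can be pinned via the environment, e.g.
#   TORCHINDUCTOR_CACHE_DIR=/tmp/my_inductor_cache python train.py
# otherwise cache_dir() falls back to <tempdir>/torchinductor_<username>.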
HAS_COLORAMA = True
try:
    import colorama
except ImportError:
    HAS_COLORAMA = False
def _color_text(msg, color):
    if not HAS_COLORAMA:
        return msg
    return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
def green_text(msg):
    return _color_text(msg, "green")
def yellow_text(msg):
    return _color_text(msg, "yellow")
def red_text(msg):
    return _color_text(msg, "red")
def blue_text(msg):
    return _color_text(msg, "blue")
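# Usage sketch: with colorama installed the helpers wrap the message in ANSI
# color codes; without it they return the message unchanged.
#   print(red_text("slow kernel"), green_text("fast kernel"))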