Files
pytorch/torch/_inductor/runtime/runtime_utils.py
Sam Larsen ff17d2b83e [easy][logging] Remove dynamo_timed fwd_only param (#140993)
Summary: It's ignored; remove it

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140993
Approved by: https://github.com/ezyang
2024-11-20 02:31:51 +00:00

167 lines
4.3 KiB
Python

# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
import functools
import operator
import torch
from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
cache_dir,
default_cache_dir,
triton_cache_dir,
)
def conditional_product(*args):
return functools.reduce(operator.mul, [x for x in args if x])
def ceildiv(numer: int, denom: int) -> int:
return -(numer // -denom)
def is_power_of_2(n: int) -> bool:
"""Returns whether n = 2 ** m for some integer m."""
return n > 0 and n & n - 1 == 0
def next_power_of_2(n: int) -> int:
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
"""
Return the total number of bytes the arguments of tensor type takes.
For in/out args, tensor sizes are counted twice: once for reading and
once for writing.
The first num_in_out_args arguments are in out tensors.
"""
return sum(
arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
for i, arg in enumerate(args)
if isinstance(arg, torch.Tensor)
)
def triton_config_to_hashable(cfg):
"""
Convert triton config to a tuple that can uniquely identify it. We can use
the return value as a dictionary key.
"""
items = sorted(cfg.kwargs.items())
items.append(("num_warps", cfg.num_warps))
items.append(("num_stages", cfg.num_stages))
return tuple(items)
def validate_triton_config(cfg):
# [Note: Triton pre_hook in inductor]
# pre-hook is a lambda function, which we don't attempt to serialize.
# right now, if a pre-hook is attached to the config, it will not be saved;
# and then it won't be used when the config is loaded from cache.
# So we assert - if we do get a pre_hook, it might get ignored after caching.
assert (
getattr(cfg, "pre_hook", None) is None
), "triton configs with pre_hooks not supported"
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
slow = ms > 0.012 and gb_per_s < 650
return red_text(info_str) if color and slow else info_str
def get_max_y_grid():
return 65535
try:
import colorama
HAS_COLORAMA = True
except ModuleNotFoundError:
HAS_COLORAMA = False
colorama = None # type: ignore[assignment]
def _color_text(msg, color):
if not HAS_COLORAMA:
return msg
return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
def green_text(msg):
return _color_text(msg, "green")
def yellow_text(msg):
return _color_text(msg, "yellow")
def red_text(msg):
return _color_text(msg, "red")
def blue_text(msg):
return _color_text(msg, "blue")
def get_first_attr(obj, *attrs):
"""
Return the first available attribute or throw an exception if none is present.
"""
for attr in attrs:
if hasattr(obj, attr):
return getattr(obj, attr)
raise AssertionError(f"{obj} does not has any of the attributes: {attrs}")
try:
dynamo_timed = torch._dynamo.utils.dynamo_timed # type: ignore[has-type]
except AttributeError: # Compile workers only have a mock version of torch
@contextlib.contextmanager
def dynamo_timed(
key,
phase_name=None,
metadata=None,
dynamo_compile_column_us=None,
):
yield
def triton_hash_to_path_key(key):
# In early versions of Triton, the hash is directly used in the path name.
# Later, the hash is converted to base64 before being used in the path name.
# Later, the base64 convertion was replaced to the base32
#
# This code tries to import _base64 and falls back to _base32 if _base64 is unavailable.
#
# To handle this, try to import the to-base64-conversion function.
# If it exists, use it; otherwise, try using _base32; if both are unavailable, use the hash directly.
try:
from triton.runtime.cache import _base64
return _base64(key)
except Exception as e:
try:
from triton.runtime.cache import _base32
return _base32(key)
except Exception as e:
return key