from __future__ import annotations

import functools
import operator
from typing import Any, TYPE_CHECKING

import torch

# NOTE: other files rely on the imports below
from torch._dynamo import callback as compilation_callback  # noqa: F401
from torch._inductor.runtime.cache_dir_utils import (  # noqa: F401
    cache_dir,
    default_cache_dir,
    triton_cache_dir,
)


if TYPE_CHECKING:
    from collections.abc import Hashable

    from .triton_compat import Config


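# Product of only the truthy arguments, e.g. conditional_product(2, 0, 8) == 16
# (the zero is skipped rather than zeroing out the product).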
def conditional_product(*args: int) -> int:
    return functools.reduce(operator.mul, [x for x in args if x])


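# Ceiling division using only integer ops: -(a // -b), e.g. ceildiv(7, 2) == 4.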
def ceildiv(number: int, denom: int) -> int:
    return -(number // -denom)


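# A power of two has exactly one bit set, so n & (n - 1) clears it to zero,
# e.g. is_power_of_2(8) is True while is_power_of_2(6) is False.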
def is_power_of_2(n: int) -> bool:
    """Returns whether n = 2 ** m for some integer m."""
    return n > 0 and n & n - 1 == 0


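# The shifts below smear the highest set bit of (n - 1) into all lower bits,
# so adding 1 yields the next power of two, e.g. next_power_of_2(5) == 8 and
# next_power_of_2(8) == 8.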
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n


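# Example: with num_in_out_args=1, a 1024-element float32 in/out tensor counts
# as 8192 bytes (read + write), while a read-only one counts as 4096 bytes.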
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes that the tensor-type arguments take.

    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.

    The first num_in_out_args arguments are in/out tensors.
    """
    return sum(
        arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
        for i, arg in enumerate(args)
        if isinstance(arg, torch.Tensor)
    )


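# For a config with (hypothetical) kwargs {"XBLOCK": 128}, num_warps=4 and
# num_stages=2, this returns
# (("XBLOCK", 128), ("num_warps", 4), ("num_stages", 2)).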
def triton_config_to_hashable(cfg: Config) -> Hashable:
    """
    Convert triton config to a tuple that can uniquely identify it. We can use
    the return value as a dictionary key.
    """
    items = sorted(cfg.kwargs.items())
    items.append(("num_warps", cfg.num_warps))
    items.append(("num_stages", cfg.num_stages))
    return tuple(items)


def validate_triton_config(cfg: Config) -> None:
    # [Note: Triton pre_hook in inductor]
    # pre_hook is a lambda function, which we don't attempt to serialize.
    # Right now, if a pre_hook is attached to the config, it will not be saved,
    # and then it won't be used when the config is loaded from cache.
    # So we assert: if we do get a pre_hook, it might get ignored after caching.
    assert getattr(cfg, "pre_hook", None) is None, (
        "triton configs with pre_hooks not supported"
    )


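# Formats a tab-separated timing/bandwidth line (ms, GB moved, GB/s); when
# color is enabled, results that look slow (ms > 0.012 and GB/s < 650) are
# returned wrapped in red.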
def create_bandwidth_info_str(
    ms: float,
    num_gb: float,
    gb_per_s: float,
    prefix: str = "",
    suffix: str = "",
    color: bool = True,
) -> str:
    info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    slow = ms > 0.012 and gb_per_s < 650
    return red_text(info_str) if color and slow else info_str


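# 65535 is the maximum CUDA grid size along the y (and z) dimensions.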
def get_max_y_grid() -> int:
    return 65535


try:
    import colorama

    HAS_COLORAMA = True
except ModuleNotFoundError:
    HAS_COLORAMA = False
    colorama = None  # type: ignore[assignment]


if HAS_COLORAMA:

    def _color_text(msg: str, color: str) -> str:
        return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET

else:

    def _color_text(msg: str, color: str) -> str:
        return msg


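# Thin wrappers around _color_text: e.g. red_text("slow") returns the message
# wrapped in ANSI red escape codes when colorama is installed, and unchanged
# otherwise.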
def green_text(msg: str) -> str:
    return _color_text(msg, "green")


def yellow_text(msg: str) -> str:
    return _color_text(msg, "yellow")


def red_text(msg: str) -> str:
    return _color_text(msg, "red")


def blue_text(msg: str) -> str:
    return _color_text(msg, "blue")


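# Useful when an attribute has been renamed across library versions, e.g.
# get_first_attr(cfg, "new_name", "old_name") (illustrative names) returns
# whichever attribute is present first.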
def get_first_attr(obj: Any, *attrs: str) -> Any:
    """
    Return the first available attribute or throw an exception if none is present.
    """
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)

    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")


dynamo_timed = torch._dynamo.utils.dynamo_timed  # type: ignore[has-type]


def triton_hash_to_path_key(key: str) -> str:
    # In early versions of Triton, the hash is used directly in the path name.
    # Later, the hash was converted to base64 before being used in the path name,
    # and later still the base64 conversion was replaced by base32.
    #
    # To handle all of these versions, try to import the base64 conversion
    # function; if it exists, use it. Otherwise, fall back to _base32, and if
    # both are unavailable, use the hash directly.
    try:
        from triton.runtime.cache import _base64

        return _base64(key)
    except Exception:
        try:
            from triton.runtime.cache import _base32

            return _base32(key)
        except Exception:
            return key


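# Illustrative use (requires an MPS-enabled build; the kernel body is a minimal
# Metal no-op, not from this file):
#   lib = compile_mps_shader("kernel void noop(uint i [[thread_position_in_grid]]) {}")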
def compile_mps_shader(source: str) -> Any:
    """
    Compile shader source, raising a more actionable error message when needed.
    """
    try:
        return torch.mps.compile_shader(source)
    except SyntaxError as err:
        raise SyntaxError(f"failed to compile {source} with {err.msg}") from err