from __future__ import annotations

import dataclasses
import hashlib
import logging
import os
import os.path
import re
from typing import Optional, TYPE_CHECKING
from typing_extensions import override

import torch
from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
from torch.utils._triton import has_triton

from ..remote_cache import (
    create_cache,
    JsonDataTy,
    RemoteCache,
    RemoteCacheBackend,
    RemoteCacheJsonSerde,
)
from .triton_compat import Config


if TYPE_CHECKING:
    from ..remote_cache import Sample

log = logging.getLogger(__name__)


_InductorMetaTy = dict[str, object]


def inductor_meta_from_config() -> _InductorMetaTy:
    from torch._inductor import config

    backend_hash = None
    if has_triton():
        try:
            backend_hash = torch.utils._triton.triton_hash_with_backend()
        except RuntimeError:
            # This can raise:
            # RuntimeError: 0 active drivers ([]). There should only be one.
            pass

    is_hip = None
    if torch.version.hip is not None:
        is_hip = True

    return {
        "autotune_local_cache": config.autotune_local_cache,
        "autotune_remote_cache": config.autotune_remote_cache,
        "backend_hash": backend_hash,
        "bundled_autotune_remote_cache": config.bundled_autotune_remote_cache,
        "coordinate_descent_tuning": config.coordinate_descent_tuning,
        "is_fbcode": config.is_fbcode(),
        "is_hip": is_hip,
    }


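# Illustrative shape of the mapping returned above (the values here are made
# up for the example and are not real settings or hashes):
#
#   {
#       "autotune_local_cache": True,
#       "autotune_remote_cache": None,
#       "backend_hash": "3f7c...e1",  # None when no Triton backend/driver is active
#       "bundled_autotune_remote_cache": None,
#       "coordinate_descent_tuning": False,
#       "is_fbcode": False,
#       "is_hip": None,
#   }

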
@dataclasses.dataclass
class AutotuneCache:
    configs_hash: str
    local_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None
    remote_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None

    # Create an AutotuneCache. Returns None if none of the caches can be used.
    @staticmethod
    def create(
        inductor_meta: _InductorMetaTy, filename: str, configs_hash: str
    ) -> Optional[AutotuneCache]:
        cache = AutotuneCache(configs_hash)
        key = AutotuneCache._prepare_key(filename)
        cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key)
        cache._setup_remote_autotune_cache(inductor_meta, key)
        if cache.local_cache or cache.remote_cache:
            return cache
        else:
            return None

    @staticmethod
    def _prepare_key(filename: str) -> str:
        from torch.compiler import config as cconfig

        # The base of filename is already a sha256 hash of the source contents.
        key = f"{os.path.basename(filename)}:{cconfig.cache_key_tag}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    # Read the best config options from the most local cache and return it.
    def _read(self) -> Optional[dict[str, JsonDataTy]]:
        if local_cache := self.local_cache:
            cache, key = local_cache
            if best_config := cache.get(key):
                if isinstance(best_config, dict):
                    return best_config

        if remote_cache := self.remote_cache:
            cache, key = remote_cache
            if best_config := cache.get(key):
                if isinstance(best_config, dict):
                    return best_config

        return None

    # Read the best config options from the most local cache and figure out
    # which `configs` represents that option.
    def read_best(
        self, inductor_meta: _InductorMetaTy, configs: list[Config]
    ) -> Optional[Config]:
        if best := self._read():
            return _load_cached_autotuning(
                best, self.configs_hash, configs, inductor_meta
            )
        return None

    # Set up local filesystem caching information
    def _setup_local_cache(
        self, inductor_meta: _InductorMetaTy, dirname: str, cache_key: str
    ) -> None:
        if not inductor_meta.get("autotune_local_cache", True):
            return

        cache_filename = f"{dirname}/{cache_key}.best_config"
        local_cache = LocalAutotuneCache()
        self.local_cache = (local_cache, cache_filename)

    # Set up remote caching information
    def _setup_remote_autotune_cache(
        self, inductor_meta: _InductorMetaTy, cache_key: str
    ) -> None:
        if not _should_use_remote_autotune_cache(inductor_meta):
            return

        if (backend_hash := inductor_meta.get("backend_hash", None)) is None:
            log.debug(
                "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache"
            )
            return
        assert isinstance(backend_hash, str)

        is_fbcode = bool(inductor_meta.get("is_fbcode", False))

        salt = "autotune-best-config-v2"
        key = backend_hash + self.configs_hash + salt
        key = hashlib.sha256(key.encode("utf-8")).hexdigest()

        remote_cache = create_cache(
            key,
            is_fbcode,
            "FbRemoteAutotuneCache",
            "RemoteAutotuneCache",
        )
        if not remote_cache:
            return

        self.remote_cache = (remote_cache, cache_key)

    # Save the config in the caches
    def save(
        self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False
    ) -> None:
        data = {
            **config.kwargs,
            "num_warps": config.num_warps,
            "num_stages": config.num_stages,
            "configs_hash": self.configs_hash,
            "found_by_coordesc": found_by_coordesc,
            "time_taken_ms": time_taken_ns // 1000000,  # Convert from NS to MS
        }

        if local_cache := self.local_cache:
            cache, key = local_cache
            cache.put(key, data)
            AutotuneCacheBundler.put(key, data)
            CacheArtifactManager.record_artifact(
                CacheArtifactType.AUTOTUNE, os.path.basename(key), data
            )

            if log.isEnabledFor(logging.DEBUG):
                type_str = "coordesc" if found_by_coordesc else "heuristic"
                log.debug("Save %s tuning result to %s", type_str, key)

        if remote_cache := self.remote_cache:
            cache, key = remote_cache
            cache.put(key, data)


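# Rough usage sketch for AutotuneCache (illustrative only; `filename`,
# `configs`, `best_config` and `time_taken_ns` below are hypothetical values
# supplied by the autotuner, not names defined in this module):
#
#   inductor_meta = inductor_meta_from_config()
#   cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
#   if cache is not None:
#       if (best := cache.read_best(inductor_meta, configs)) is not None:
#           # Cache hit: reuse the previously autotuned config.
#           ...
#       else:
#           # Cache miss: benchmark, then persist the winner.
#           cache.save(best_config, time_taken_ns)

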
class _AutotuneCacheBundlerImpl:
    """
    Caches a set of LocalAutotuneCacheBackend entries together in a single
    cache.
    """

    _key: str
    _cache: RemoteCache[JsonDataTy]

    # All known entries from LocalAutotuneCache.put()
    _entries: dict[str, JsonDataTy]

    def end_compile(self) -> None:
        # TODO: Do we need to compute time_taken_ms and encode that somehow?
        if self._entries:
            self._cache.put(self._key, self._entries)

    def put(self, basename: str, data: JsonDataTy) -> None:
        # Do we need to worry about duplicates? We only have a single local fs
        # entry - so probably not.
        self._entries[basename] = data

    def __init__(self, key: str, cache: RemoteCache[JsonDataTy]) -> None:
        self._key = key
        self._cache = cache
        self._entries = {}

    def sync(self) -> None:
        # We don't currently use this - but we could async load starting at
        # `begin_compile` and wait for the load to be finished here.
        pass

    @classmethod
    def _should_use_bundled_autotune_remote_cache(
        cls, inductor_meta: _InductorMetaTy
    ) -> bool:
        # The bundled autotune cache is only available if you've also got local
        # caching enabled (because we feed the bundled data to the local cache).
        if not inductor_meta.get("autotune_local_cache", True):
            return False

        # Check if we're enabled via config
        if (
            bundled_autotune_remote_cache := inductor_meta.get(
                "bundled_autotune_remote_cache"
            )
        ) is not None:
            return bool(bundled_autotune_remote_cache)

        if not cls._get_is_fbcode(inductor_meta):
            return False
        if torch._utils_internal.is_fb_unit_test():
            return False
        if inductor_meta.get("is_hip"):
            return False

        try:
            from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
        except ModuleNotFoundError:
            return False

        jk = torch._utils_internal.justknobs_getval_int(
            "pytorch/remote_cache:bundled_autotune_remote_cache_version"
        )
        return REMOTE_CACHE_VERSION >= jk

    def _load_cache(self) -> bool:
        from torch._inductor import codecache

        # The single key is defined on construction of the cache.
        entries = self._cache.get(self._key)
        if entries is None or not isinstance(entries, dict):
            # We couldn't load the cache - so mark _entries as non-None so we
            # store local cache values.
            return False

        # Go through the entries we got from the cache and save them locally.
        time_saved_ns = 0
        for basename, data in entries.items():
            # Reconstruct the final filename (see put())
            root, ext = _splitext_nodot(basename)
            _, _, filename = codecache.get_path(root, ext)
            if isinstance(data, dict) and (tsns := data.get("time_saved_ns")):
                time_saved_ns += int(tsns)  # type: ignore[arg-type]
            local_cache = LocalAutotuneCache()
            local_cache.put(filename, data)

        codecache.add_ephemeral_timeout_increase_for_distributed(time_saved_ns)

        return True

    @staticmethod
    def _get_is_fbcode(inductor_meta: _InductorMetaTy) -> bool:
        return bool(inductor_meta.get("is_fbcode", False))

    @staticmethod
    def _get_backend_hash(inductor_meta: _InductorMetaTy) -> str:
        backend_hash = inductor_meta["backend_hash"]
        assert isinstance(backend_hash, str)
        return backend_hash


class AutotuneCacheBundler:
    _bundler: Optional[_AutotuneCacheBundlerImpl] = None

    def __init__(self) -> None:
        pass

    # Call this before we start any autotune computation for an inductor Python
    # file. On a cache hit it copies the individual results into the local
    # autotune caches.
    @classmethod
    def begin_compile(
        cls,
        inductor_meta: _InductorMetaTy,
        *,
        code: Optional[str] = None,
        code_hash: Optional[str] = None,
    ) -> None:
        assert cls._bundler is None

        if code is not None:
            assert code_hash is None, "Cannot specify both code and code_hash"
            code_hash = _comment_stripped_hash(code)
        assert code_hash is not None

        if not _AutotuneCacheBundlerImpl._should_use_bundled_autotune_remote_cache(
            inductor_meta
        ):
            return

        cache = create_cache(
            "bundled-autotune-v1",
            _AutotuneCacheBundlerImpl._get_is_fbcode(inductor_meta),
            "FbRemoteBundledAutotuneCache",
            "RemoteBundledAutotuneCache",
        )
        if not cache:
            return

        # We're starting a compilation phase. We have a cache key for the code
        # we're compiling. We'll get the individual autotune bundles later (via
        # self.put()). For now create the AutotuneCacheBundler and try to load
        # from the cache.

        salt = "bundled-autotune-best-configs-v1"
        backend_hash = _AutotuneCacheBundlerImpl._get_backend_hash(inductor_meta)
        # TODO: The autotune cache includes configs_hash in the key. The problem
        # is that the configs_hash includes info from the individual pointwise()
        # calls (size_hints, for example) which we can't know yet. I *think*
        # that info is basically present in the `code_hash` (since it's a
        # parameter to the pointwise decorator) - but is there other info we
        # need to include from inductor_meta?
        key = code_hash + backend_hash + salt
        key = hashlib.sha256(key.encode("utf-8")).hexdigest()

        bundler = _AutotuneCacheBundlerImpl(key, cache)
        if not bundler._load_cache():
            # We couldn't load from the cache - so save the data so we can store
            # the saved autotunes.
            cls._bundler = bundler

        # If we get a cache hit, don't bother saving any of the individual
        # autotune results.

    # Call this after all individual autotune results are finished for an
    # inductor Python file. If we gathered any individual results then we
    # bundle them and put the bundle into the cache.
    @classmethod
    def end_compile(cls) -> None:
        if bundler := cls._bundler:
            cls._bundler = None
            bundler.end_compile()

    @classmethod
    def sync(cls) -> None:
        if bundler := cls._bundler:
            bundler.sync()

    @classmethod
    def put(cls, filename: str, data: JsonDataTy) -> None:
        if bundler := cls._bundler:
            # The filename comes in as something like
            # "/tmp/tmp{random}/{aa}/{basename}.py" (where aa is
            # basename[1:3]). Strip it down and make sure that it looks like a
            # path we could reconstruct (because it's possible for the caller
            # to customize the path).
            basename = os.path.basename(filename)

            # TODO: check cache_dir() vs filename, then strip dirname
            bundler.put(basename, data)


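# Rough lifecycle sketch for the bundler (illustrative only; `source_code`,
# `best_config_path` and `data` below are hypothetical values produced during
# compilation, not names defined in this module):
#
#   AutotuneCacheBundler.begin_compile(inductor_meta_from_config(), code=source_code)
#   ...  # individual autotune results flow in via LocalAutotuneCache.put(),
#        # which forwards each entry with AutotuneCacheBundler.put(best_config_path, data)
#   AutotuneCacheBundler.end_compile()  # uploads the bundled entries, if any were gathered

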
# Remove the comments from the code (which include things like run ids and file
# paths) and then hash the result.
def _comment_stripped_hash(code: str) -> str:
    code = re.sub(r"#.*$", "", code, count=0, flags=re.MULTILINE)
    return torch._inductor.codecache.code_hash(code)


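# For example (illustrative), these two snippets hash identically because they
# differ only in their comments:
#
#   _comment_stripped_hash("x = 1  # compiled at /tmp/run1\n")
#   _comment_stripped_hash("x = 1  # compiled at /tmp/run2\n")

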
def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool:
    if (config := inductor_meta.get("autotune_remote_cache")) is not None:
        return bool(config)
    if not inductor_meta.get("is_fbcode"):
        return False
    if torch._utils_internal.is_fb_unit_test():
        return False
    if inductor_meta.get("is_hip"):
        return False

    try:
        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
    except ModuleNotFoundError:
        return False

    return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
        "pytorch/remote_cache:autotune_memcache_version"
    )


def _load_cached_autotuning(
    best_config: dict[str, JsonDataTy],
    configs_hash: str,
    configs: list[Config],
    inductor_meta: _InductorMetaTy,
) -> Optional[Config]:
    if best_config is None:
        return None
    if best_config.pop("configs_hash", None) != configs_hash:
        return None

    # Remove the time taken so it doesn't interfere with the comparison below.
    best_config.pop("time_taken_ms", None)

    if inductor_meta.get("coordinate_descent_tuning") and best_config.pop(
        "found_by_coordesc", False
    ):
        num_warps = best_config.pop("num_warps")
        num_stages = best_config.pop("num_stages")
        triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
        triton_config.found_by_coordesc = True
        return triton_config

    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
        and cfg.num_warps == best_config.get("num_warps")
        and cfg.num_stages == best_config.get("num_stages")
    ]
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]


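# For example (illustrative values), a cached entry like
#
#   {"XBLOCK": 128, "num_warps": 4, "num_stages": 1,
#    "configs_hash": "...", "time_taken_ms": 7}
#
# is matched back to the single candidate Config whose kwargs are
# {"XBLOCK": 128} with num_warps=4 and num_stages=1; if zero or several
# candidates match, the cached result is ignored and None is returned.

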
class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]):
    @override
    def _get(self, key: str) -> Optional[bytes]:
        try:
            with open(key, "rb") as fd:
                return fd.read()
        except FileNotFoundError:
            return None

    @override
    def _put(self, key: str, data: bytes) -> None:
        os.makedirs(os.path.dirname(key), exist_ok=True)
        with open(key, "wb") as fd:
            fd.write(data)


class LocalAutotuneCache(RemoteCache[JsonDataTy]):
    def __init__(self) -> None:
        backend = _LocalAutotuneCacheBackend()
        serde = RemoteCacheJsonSerde()
        super().__init__(backend, serde)

    @override
    def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]:
        AutotuneCacheBundler.sync()
        result = super()._get(key, sample)
        if result is not None:
            assert isinstance(result, dict)
            # What? Why are we doing a put() here? Imagine we have a new model
            # that reuses some existing kernels that have already been
            # compiled. If we didn't do a `put` here (on cache hit) then the new
            # model would only bundle *newly* compiled kernels, not existing
            # kernels that were already compiled and cached.
            AutotuneCacheBundler.put(key, result)
            CacheArtifactManager.record_artifact(
                CacheArtifactType.AUTOTUNE, os.path.basename(key), result
            )
        return result

    @override
    def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None:
        AutotuneCacheBundler.put(key, value)
        super()._put(key, value, sample)


def _splitext_nodot(basename: str) -> tuple[str, str]:
    root, ext = os.path.splitext(basename)
    if ext:
        ext = ext[1:]
    return root, ext
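
# For example (illustrative): _splitext_nodot("kernel.best_config") returns
# ("kernel", "best_config"), while _splitext_nodot("kernel") returns
# ("kernel", "").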