from __future__ import annotations

import atexit
import collections
import dataclasses
import functools
import json
import logging
import os
import sys
import typing
from abc import abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing_extensions import override, TypeAlias

from torch._dynamo.utils import dynamo_timed
from torch._inductor import config


try:
    import redis
except ImportError:
    redis = None  # type: ignore[assignment]


log = logging.getLogger(__name__)


if config.is_fbcode():
    from rfe.scubadata.scubadata_py3 import (  # type: ignore[import-not-found]
        Sample as Sample_,
    )

    Sample: TypeAlias = Sample_
else:
    Sample: TypeAlias = Type[object]  # type: ignore[misc,no-redef]


_T = TypeVar("_T")
_U = TypeVar("_U")


remote_fx_cache_get_timed = functools.partial(
    dynamo_timed,
    "FbRemoteFxGraphCache.get",
    phase_name="remote_fx_graph_cache_get",
    log_pt2_compile_event=False,
    fwd_only=False,
)
remote_fx_cache_put_timed = functools.partial(
    dynamo_timed,
    "FbRemoteFxGraphCache.put",
    phase_name="remote_fx_graph_cache_put",
    log_pt2_compile_event=False,
    fwd_only=False,
)
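

# These timing helpers wrap dynamo_timed, which is used as a context manager.
# A sketch of the intended usage around a remote fx-graph-cache call
# (illustrative only):
#
#   with remote_fx_cache_get_timed():
#       ...  # perform the remote lookup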


class RemoteCacheBackend(Generic[_T]):
    """
    A backend implementation for accessing a remote/distributed cache. Only
    works with bytes in/out. For structured data use a RemoteCache.
    """

    def __init__(self) -> None:
        self._name = f"backend:{type(self).__name__}"

    @abstractmethod
    def _get(self, key: str) -> Optional[_T]:
        pass

    @abstractmethod
    def _put(self, key: str, data: _T) -> None:
        pass

    def get(self, key: str) -> Optional[_T]:
        try:
            value = self._get(key)
            cache_stats.get(self._name, value)
        except Exception:
            cache_stats.exception(self._name)
            raise
        return value

    def put(self, key: str, data: _T) -> None:
        try:
            self._put(key, data)
            cache_stats.put(self._name)
        except Exception:
            cache_stats.exception(self._name)
            raise
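

# A minimal in-memory backend implementing the abstract hooks above, as a
# sketch for illustration only (this module does not ship such a backend; the
# class name is hypothetical). It pairs well with `backend_override_cls` below.
class _InMemoryCacheBackend(RemoteCacheBackend[bytes]):
    def __init__(self) -> None:
        super().__init__()
        self._store: Dict[str, bytes] = {}

    @override
    def _get(self, key: str) -> Optional[bytes]:
        return self._store.get(key)

    @override
    def _put(self, key: str, data: bytes) -> None:
        self._store[key] = data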


# Serde that encodes from _T to _U and decodes from _U to _T.
class RemoteCacheSerde(Generic[_T, _U]):
    @abstractmethod
    def encode(self, data: _T) -> _U:
        pass

    @abstractmethod
    def decode(self, data: _U) -> _T:
        pass


JsonDataTy = Optional[
    Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
]


class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    def encode(self, data: JsonDataTy) -> bytes:
        return bytes(json.dumps(data), "ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        return json.loads(data)


class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    def encode(self, data: _T) -> _T:
        return data

    def decode(self, data: _T) -> _T:
        return data
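

# A sketch of a custom serde for backends that store bytes, assuming the
# cached values are safe to pickle (illustrative only; this module does not
# define or use such a serde):
class _PickleSerde(RemoteCacheSerde[object, bytes]):
    def encode(self, data: object) -> bytes:
        import pickle

        return pickle.dumps(data)

    def decode(self, data: bytes) -> object:
        import pickle

        return pickle.loads(data)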


# This class is the top of a RemoteCache. A RemoteCache is fundamentally made
# of three parts:
#
# 1. The controller (this class).
# 2. A serializer/deserializer (instance of RemoteCacheSerde).
# 3. A backend (instance of RemoteCacheBackend).
#
# To write (`put`), the RemoteCache takes data, uses the RemoteCacheSerde to
# convert it for the backend and passes it to the backend.
#
# Conversely, when reading (`get`), the RemoteCache takes data from the
# backend, uses the RemoteCacheSerde to convert it and returns it.
#
# The RemoteCacheBackend is generic on _U - which is the type of data the
# backend can directly cache (usually `bytes`).
#
# The RemoteCacheSerde is responsible for converting between _T (the type of
# data the RemoteCache accepts in `put` and returns in `get`) and _U.
#
# Rather than instantiating RemoteCache directly, you should subclass it. The
# reason is that when logging cache use (`TORCH_LOGS=cache`) we use the
# concrete type of the RemoteCache as the reported cache name. See
# RemoteFxGraphCache below as an example.
class RemoteCache(Generic[_T]):
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        # Support for testing to mock out the backend on a class-by-class basis.
        if (override_cls := self.__class__.backend_override_cls) is not None:
            self.backend = override_cls()
        else:
            self.backend = backend
        self.serde = serde

    # Fetch `key` from the cache. Returns `None` if the value is not present
    # in the cache.
    def get(self, key: str) -> Optional[_T]:
        sample = self._create_sample()
        try:
            result = self._get(key, sample)
            cache_stats.get(type(self).__name__, result)
        except Exception:
            cache_stats.exception(type(self).__name__)
            raise
        self._log_sample(sample)
        return result

    # Add `value` to the cache with the key `key`. Note that `None` is not a
    # valid value even if _T supports it (because you can't tell the difference
    # between `None` and a missing cache entry).
    def put(self, key: str, value: _T) -> None:
        assert value is not None
        sample = self._create_sample()
        try:
            self._put(key, value, sample)
            cache_stats.put(type(self).__name__)
        except Exception:
            cache_stats.exception(type(self).__name__)
            raise
        self._log_sample(sample)

    # Used to convert data from the cache into structured data.
    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:  # type: ignore[override]
        return self.serde.decode(data)  # type: ignore[arg-type]

    # Used to convert structured data into data for the cache.
    def _encode(self, value: _T, sample: Optional[Sample]) -> object:  # returns _U
        return self.serde.encode(value)

    # Get structured data from the cache.
    # Separate from `get` so that it can be overridden.
    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        if data := self._backend_get(key):
            return self._decode(data, sample)
        return None

    # Get unstructured data from the cache.
    # Separate from `get` so that it can be overridden.
    # Returns _U - but we aren't actually generic on _U.
    def _backend_get(self, key: str) -> object:
        return self.backend.get(key)

    # Put structured data into the cache.
    # Separate from `put` so that it can be overridden.
    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        data = self._encode(value, sample)
        self._backend_put(key, data)

    # Put unstructured data into the cache.
    # Separate from `put` so that it can be overridden.
    # Takes data: _U - but we aren't actually generic on _U.
    def _backend_put(self, key: str, data: object) -> None:
        self.backend.put(key, data)

    # Create a logging Sample - used with internal loggers to monitor cache
    # effectiveness.
    def _create_sample(self) -> Optional[Sample]:
        return None

    # Write the logging Sample to the logger.
    def _log_sample(self, sample: Optional[Sample]) -> None:
        pass
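

# A sketch wiring the three parts together, using the hypothetical
# _InMemoryCacheBackend above and a throwaway subclass so that stats report a
# concrete cache name (illustrative only; nothing calls this helper):
def _example_round_trip() -> None:
    class _ExampleCache(RemoteCache[JsonDataTy]):
        pass

    cache = _ExampleCache(_InMemoryCacheBackend(), RemoteCacheJsonSerde())
    cache.put("key", {"value": 1})  # encoded to JSON bytes for the backend
    assert cache.get("key") == {"value": 1}  # fetched and decoded back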


class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    A Redis implementation of a remote/distributed cache.
    """

    _key_fmt: str
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        super().__init__()
        if not redis:
            # We had trouble importing redis - just skip init.
            return

        self._key_fmt = f"pt2:{cache_id}:{{key}}"
        self._redis = redis.Redis(
            host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)),
        )

    def __get_key(self, key: str) -> str:
        return self._key_fmt.format(key=key)

    @override
    def _get(self, key: str) -> Optional[bytes]:
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return None

        try:
            value = self._redis.get(self.__get_key(key))
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
            return None

        # In theory redis.get() can return an Awaitable as well...
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def _put(self, key: str, data: bytes) -> None:
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return

        try:
            self._redis.set(self.__get_key(key), data)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
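

# A quick usage sketch of the backend, assuming a Redis server reachable at
# TORCHINDUCTOR_REDIS_HOST:TORCHINDUCTOR_REDIS_PORT (default localhost:6379).
# The cache id "example" is hypothetical and nothing calls this helper:
def _example_redis_backend() -> None:
    backend = RedisRemoteCacheBackend("example")
    backend.put("some-key", b"some-bytes")  # stored under "pt2:example:some-key"
    # get() returns None when redis is unavailable or the key is missing.
    assert backend.get("some-key") in (b"some-bytes", None)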


class RedisRemoteCache(RemoteCache[JsonDataTy]):
    def __init__(self, key: str) -> None:
        # Special test handling: if we're just going to override the backend
        # anyway, don't require redis.
        if self.__class__.backend_override_cls:
            # This is totally bogus but it works for now...
            backend = typing.cast(RemoteCacheBackend[bytes], None)
        else:
            backend = RedisRemoteCacheBackend(key)
        serde = RemoteCacheJsonSerde()
        super().__init__(backend, serde)


class RemoteAutotuneCache(RedisRemoteCache):
    pass


class RemoteBundledAutotuneCache(RedisRemoteCache):
    pass


class RemoteFxGraphCache(RedisRemoteCache):
    pass


class RemoteAOTAutogradCache(RedisRemoteCache):
    pass


class RemoteDynamoPGOCache(RedisRemoteCache):
    pass
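

# A sketch of the `backend_override_cls` testing hook on one of the concrete
# caches above, using the hypothetical _InMemoryCacheBackend (illustrative
# only; nothing calls this helper):
def _example_backend_override() -> None:
    RemoteFxGraphCache.backend_override_cls = _InMemoryCacheBackend
    try:
        cache = RemoteFxGraphCache("unused-cache-id")  # redis is not required here
        cache.put("key", [1, 2, 3])
        assert cache.get("key") == [1, 2, 3]
    finally:
        RemoteFxGraphCache.backend_override_cls = None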


def create_cache(
    key: str,
    is_fbcode: bool,
    fb_cache_cls: str,
    oss_cache_cls: str,
) -> Optional[RemoteCache[JsonDataTy]]:
    try:
        if is_fbcode:
            import torch._inductor.fb.remote_cache

            cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls)
            return cache_cls(key)
        else:
            this_module = sys.modules[__name__]

            cache_cls = getattr(this_module, oss_cache_cls)
            return cache_cls(key)

    except Exception:
        log.warning("Unable to create a remote cache", exc_info=True)
        return None
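

# A short usage sketch of the factory above; in OSS builds it resolves the
# class named by `oss_cache_cls` from this module. The cache id "example" is
# hypothetical and nothing calls this helper:
def _example_create_cache() -> Optional[RemoteCache[JsonDataTy]]:
    return create_cache(
        "example", config.is_fbcode(), "FbRemoteFxGraphCache", "RemoteFxGraphCache"
    )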


# Some simple stat capture
@dataclasses.dataclass
class _CacheStat:
    miss: int = 0
    hit: int = 0
    put: int = 0
    exception: int = 0

    def __str__(self) -> str:
        return f"{{hit: {self.hit}, miss: {self.miss}, put: {self.put}, exception: {self.exception}}}"


class _CacheStats:
    _stats: Dict[str, _CacheStat]

    def __init__(self) -> None:
        self._stats = collections.defaultdict(_CacheStat)

    def miss(self, name: str, count: int = 1) -> None:
        self._stats[name].miss += count

    def hit(self, name: str, count: int = 1) -> None:
        self._stats[name].hit += count

    def get(self, name: str, value: Optional[object]) -> None:
        if value is None:
            self.miss(name)
        else:
            self.hit(name)

    def put(self, name: str, count: int = 1) -> None:
        self._stats[name].put += count

    def exception(self, name: str, count: int = 1) -> None:
        self._stats[name].exception += count


cache_stats = _CacheStats()
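

# A small illustration of how the counters accumulate (the cache name
# "ExampleCache" is hypothetical; nothing calls this helper):
def _example_cache_stats() -> None:
    cache_stats.get("ExampleCache", None)  # a None value counts as a miss
    cache_stats.get("ExampleCache", b"data")  # a non-None value counts as a hit
    cache_stats.put("ExampleCache")
    assert str(cache_stats._stats["ExampleCache"]) == "{hit: 1, miss: 1, put: 1, exception: 0}"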


@atexit.register
def dump_cache_stats() -> None:
    if not log.isEnabledFor(logging.INFO):
        return

    import io

    out = io.StringIO()

    if not cache_stats._stats:
        print(" None", file=out)
    else:
        print(file=out)
        for k, v in sorted(cache_stats._stats.items()):
            print(f"  {k}: {v}", file=out)

    log.info("Cache Metrics:%s", out.getvalue())