# pytorch/torch/_inductor/mock_cache.py
# mypy: ignore-errors
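"""Mock cache backends for testing torch._inductor's remote caches.

PatchCaches swaps each remote cache's backend for an in-memory MockBackend,
and ``global_stats`` records per-cache put/hit/miss counts and cache contents
that tests can assert against.
"""
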
from __future__ import annotations

import contextlib
import dataclasses
import sys
import threading
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing_extensions import override, Self
from unittest.mock import patch

from torch._inductor import config
from torch._inductor.remote_cache import RemoteCacheBackend


if TYPE_CHECKING:
    from types import TracebackType


@dataclasses.dataclass
class Stats:
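    """Counters for a single mock cache: puts, get hits, and get misses."""
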
    num_put: int = 0
    num_get_hit: int = 0
    num_get_miss: int = 0

    def __iadd__(self, other: Stats) -> Self:
        self.num_put += other.num_put
        self.num_get_hit += other.num_get_hit
        self.num_get_miss += other.num_get_miss
        return self

    def reset(self) -> None:
        self.num_put = 0
        self.num_get_hit = 0
        self.num_get_miss = 0

    def __str__(self) -> str:
        return "".join(
            (
                f"puts: {self.num_put}, ",
                f"misses: {self.num_get_miss}, ",
                f"hits: {self.num_get_hit}, ",
            )
        )

    def __eq__(self, other: object) -> bool:
        # The dataclass-generated __eq__ requires both sides to have the same
        # type, so it can't be used to compare against _GlobalItemStats.
        return (
            isinstance(other, (Stats, _GlobalItemStats))
            and self.num_put == other.num_put
            and self.num_get_hit == other.num_get_hit
            and self.num_get_miss == other.num_get_miss
        )


class _GlobalItemStats(Stats):
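    """Stats plus the in-memory key/value store backing one mock cache."""
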
    cache: dict[str, object]

    def __init__(self) -> None:
        super().__init__()
        self.cache = {}

    def reset(self) -> None:
        super().reset()
        self.cache = {}


# The cache stats are thread-local so that tests running at the same time
# won't cross-contaminate each other. However, they need to be "global"
# because we allow code to create new cache clients which refer to the same
# cache (because it's a remote cache).
class _GlobalStats(threading.local):
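    """Per-thread collection of _GlobalItemStats, one per mocked cache."""
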
    def __init__(self) -> None:
        self.autotune_local = _GlobalItemStats()
        self.autotune_remote = _GlobalItemStats()
        self.bundled_autotune = _GlobalItemStats()
        self.fx_graph = _GlobalItemStats()
        self.triton = _GlobalItemStats()
        self.aot_autograd = _GlobalItemStats()
        self.dynamo_pgo = _GlobalItemStats()

    def reset(self) -> None:
        self.autotune_local.reset()
        self.autotune_remote.reset()
        self.bundled_autotune.reset()
        self.fx_graph.reset()
        self.triton.reset()
        self.aot_autograd.reset()
        self.dynamo_pgo.reset()

    def get_stat(self, name: str) -> _GlobalItemStats:
        return getattr(self, name)

    def report(self) -> None:
        subs = (
            ("autotune_local", self.autotune_local),
            ("autotune_remote", self.autotune_remote),
            ("bundled_autotune", self.bundled_autotune),
            ("fx_graph", self.fx_graph),
            ("triton", self.triton),
            ("aot_autograd", self.aot_autograd),
            ("dynamo_pgo", self.dynamo_pgo),
        )

        print("Cache Stats:", file=sys.stderr)
        for name, sub in subs:
            print(f"  {name}: {sub}", file=sys.stderr)

        print("Cache Entries:", file=sys.stderr)
        for name, sub in subs:
            if sub.cache:
                print(f"  {name}:", file=sys.stderr)
                for k, v in sorted(sub.cache.items()):
                    v = repr(v)
                    if len(v) > 100:
                        v = v[:100] + "..."
                    print(f"    {k!r}: {v}", file=sys.stderr)


global_stats = _GlobalStats()


class MockBackend(RemoteCacheBackend[Any]):
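    """An in-memory cache backend that records its traffic in global_stats."""
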
    def __init__(self, name: str) -> None:
        self._name = name

    @staticmethod
    def with_name(name: str) -> Callable[[], MockBackend]:
        def wrapper() -> MockBackend:
            return MockBackend(name)

        return wrapper

    @override
    def _get(self, key: str) -> Optional[Any]:
        stat = global_stats.get_stat(self._name)
        if key in stat.cache:
            stat += Stats(num_get_hit=1)
            return stat.cache.get(key)
        else:
            stat += Stats(num_get_miss=1)
            return None

    @override
    def _put(self, key: str, data: Any) -> None:
        stat = global_stats.get_stat(self._name)
        stat += Stats(num_put=1)
        stat.cache[key] = data


# Config flags that enable each cache. PatchCaches.setUp() saves and disables
# them; PatchCaches.tearDown() restores them.
_CACHE_CONFIG_EN = (
    "fx_graph_cache",
    "fx_graph_remote_cache",
    "autotune_local_cache",
    "autotune_remote_cache",
    "bundled_autotune_remote_cache",
)


class PatchCaches(contextlib.AbstractContextManager):
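    """Replaces every remote cache backend with an in-memory MockBackend.

    Usable as a context manager; the setUp/tearDown classmethods additionally
    let test classes disable the cache config flags for the duration of a test.
    """
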
    @classmethod
    def setUp(cls):
        # If this test is using PatchCaches then disable all the caches by
        # default, letting the tests turn them on explicitly. This is because
        # tests using PatchCaches will often want to check stats explicitly.
        cls._savedCacheState = {}
        for name in _CACHE_CONFIG_EN:
            if hasattr(config, name):
                cls._savedCacheState[name] = getattr(config, name)
            setattr(config, name, False)

    @classmethod
    def tearDown(cls):
        # Restore cache defaults
        for name in _CACHE_CONFIG_EN:
            delattr(config, name)
            if name in cls._savedCacheState:
                setattr(config, name, cls._savedCacheState[name])

    def __init__(self) -> None:
        self._stack = contextlib.ExitStack()

    def __enter__(self) -> Self:
        global_stats.reset()
        self._stack.__enter__()

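        # Point each OSS remote cache at a MockBackend that records its
        # traffic under the matching stats name.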
        ctx = patch(
            "torch._inductor.runtime.autotune_cache.LocalAutotuneCache.backend_override_cls",
            MockBackend.with_name("autotune_local"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls",
            MockBackend.with_name("autotune_remote"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteBundledAutotuneCache.backend_override_cls",
            MockBackend.with_name("bundled_autotune"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls",
            MockBackend.with_name("fx_graph"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteAOTAutogradCache.backend_override_cls",
            MockBackend.with_name("aot_autograd"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls",
            MockBackend.with_name("dynamo_pgo"),
        )
        self._stack.enter_context(ctx)

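        # In fbcode, also patch the internal (FB) cache implementations so
        # they route through the same mock backends.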
        if config.is_fbcode():
            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls",
                MockBackend.with_name("autotune_remote"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteBundledAutotuneCache.backend_override_cls",
                MockBackend.with_name("bundled_autotune"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls",
                MockBackend.with_name("fx_graph"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls",
                MockBackend.with_name("triton"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteAOTAutogradCache.backend_override_cls",
                MockBackend.with_name("aot_autograd"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls",
                MockBackend.with_name("dynamo_pgo"),
            )
            self._stack.enter_context(ctx)

        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self._stack.__exit__(exc_type, exc_value, traceback)
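

# A minimal usage sketch (illustrative; the compile step and the exact counts
# asserted here are assumptions, not an actual PyTorch test):
#
#     with PatchCaches():
#         ...  # run a compile that exercises the FX graph remote cache
#         assert global_stats.fx_graph == Stats(num_put=1, num_get_miss=1)
#         global_stats.report()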