Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Previously: https://github.com/pytorch/pytorch/pull/138052, but the implementation is done from scratch, so I'm opening a new PR. This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to populate this id automatically for internal MAST jobs, but generic OSS users will have to opt in explicitly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139001
Approved by: https://github.com/oulgen
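For an OSS user, opting in roughly means pinning a stable job id before compiling, so profiles recorded by one run can be applied by the next. A minimal sketch; the `torch.compiler.config.job_id` knob shown here is an assumption based on related PyTorch config work, not confirmed by this description:

    # Hedged sketch of opting into sharing automatic-dynamic profiles across runs.
    # ASSUMPTION: the opt-in is exposed as torch.compiler.config.job_id; consult
    # the PR itself for the actual knob and any environment-variable equivalent.
    import torch

    torch.compiler.config.job_id = "my-recurring-train-job"  # stable across runs

    @torch.compile
    def step(x):
        return x * 2

    step(torch.randn(8))  # a later run with the same job id reuses the profile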
274 lines · 8.4 KiB · Python
# mypy: ignore-errors

from __future__ import annotations

import contextlib
import dataclasses
import sys
import threading
from typing import Any, Callable, Dict, Optional, Type, TYPE_CHECKING
from typing_extensions import override, Self
from unittest.mock import patch

from torch._inductor import config
from torch._inductor.remote_cache import RemoteCacheBackend


if TYPE_CHECKING:
    from types import TracebackType
@dataclasses.dataclass
class Stats:
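    """Counters for one cache's traffic: puts, get hits, and get misses."""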
    num_put: int = 0
    num_get_hit: int = 0
    num_get_miss: int = 0

    def __iadd__(self, other: Stats) -> Self:
        self.num_put += other.num_put
        self.num_get_hit += other.num_get_hit
        self.num_get_miss += other.num_get_miss
        return self

    def reset(self) -> None:
        self.num_put = 0
        self.num_get_hit = 0
        self.num_get_miss = 0

    def __str__(self) -> str:
        return "".join(
            (
                f"puts: {self.num_put}, ",
                f"misses: {self.num_get_miss}, ",
                f"hits: {self.num_get_hit}, ",
            )
        )

    def __eq__(self, other: object) -> bool:
        # Dataclass's default __eq__ checks that the types are the same, so it
        # can't be used to compare against _GlobalItemStats.
        return (
            isinstance(other, (Stats, _GlobalItemStats))
            and self.num_put == other.num_put
            and self.num_get_hit == other.num_get_hit
            and self.num_get_miss == other.num_get_miss
        )


class _GlobalItemStats(Stats):
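    """Per-cache Stats plus the in-memory dict that backs the mocked cache."""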
    cache: Dict[str, object]

    def __init__(self) -> None:
        super().__init__()
        self.cache = {}

    def reset(self) -> None:
        super().reset()
        self.cache = {}


# The cache states are thread-local so that if we're running multiple tests at
# once they won't cross-contaminate. However, the stats need to be "global"
# because we allow code to create new cache clients that refer to the same
# cache (because it's a remote cache).
class _GlobalStats(threading.local):
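    """Thread-local registry of _GlobalItemStats, one attribute per mocked cache."""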
    def __init__(self) -> None:
        self.autotune_local = _GlobalItemStats()
        self.autotune_remote = _GlobalItemStats()
        self.bundled_autotune = _GlobalItemStats()
        self.fx_graph = _GlobalItemStats()
        self.triton = _GlobalItemStats()
        self.aot_autograd = _GlobalItemStats()
        self.dynamo_pgo = _GlobalItemStats()

    def reset(self) -> None:
        self.autotune_local.reset()
        self.autotune_remote.reset()
        self.bundled_autotune.reset()
        self.fx_graph.reset()
        self.triton.reset()
        self.aot_autograd.reset()
        self.dynamo_pgo.reset()

    def get_stat(self, name: str) -> _GlobalItemStats:
        return getattr(self, name)

    def report(self) -> None:
        subs = (
            ("autotune_local", self.autotune_local),
            ("autotune_remote", self.autotune_remote),
            ("bundled_autotune", self.bundled_autotune),
            ("fx_graph", self.fx_graph),
            ("triton", self.triton),
            ("aot_autograd", self.aot_autograd),
            ("dynamo_pgo", self.dynamo_pgo),
        )

        print("Cache Stats:", file=sys.stderr)
        for name, sub in subs:
            print(f"  {name}: {sub}", file=sys.stderr)

        print("Cache Entries:", file=sys.stderr)
        for name, sub in subs:
            if sub.cache:
                print(f"  {name}:", file=sys.stderr)
                for k, v in sorted(sub.cache.items()):
                    v = repr(v)
                    if len(v) > 100:
                        v = v[:100] + "..."
                    print(f"    {k!r}: {v}", file=sys.stderr)


global_stats = _GlobalStats()
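
# NOTE: global_stats is thread-local (see _GlobalStats above); PatchCaches
# resets it on __enter__, so counters always start from zero within a test.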


class MockBackend(RemoteCacheBackend[Any]):
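    """In-memory stand-in for a remote cache backend: gets and puts are served
    from the per-cache dict in global_stats while hit/miss/put counters are
    updated under the bucket given by `name`."""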
    def __init__(self, name: str) -> None:
        self._name = name

    @staticmethod
    def with_name(name: str) -> Callable[[], MockBackend]:
        def wrapper() -> MockBackend:
            return MockBackend(name)

        return wrapper

    @override
    def _get(self, key: str) -> Optional[Any]:
        stat = global_stats.get_stat(self._name)
        if key in stat.cache:
            stat += Stats(num_get_hit=1)
            return stat.cache.get(key)
        else:
            stat += Stats(num_get_miss=1)
            return None

    @override
    def _put(self, key: str, data: Any) -> None:
        stat = global_stats.get_stat(self._name)
        stat += Stats(num_put=1)
        stat.cache[key] = data


# Names of the cache-enable configs managed by PatchCaches
_CACHE_CONFIG_EN = (
    "fx_graph_cache",
    "fx_graph_remote_cache",
    "autotune_local_cache",
    "autotune_remote_cache",
    "bundled_autotune_remote_cache",
)


class PatchCaches(contextlib.AbstractContextManager):
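    """Context manager (with unittest-style setUp/tearDown class hooks) that
    patches every cache's backend_override_cls to a MockBackend, so tests can
    assert on cache traffic via global_stats."""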
    @classmethod
    def setUp(cls):
        # If this test is using PatchCaches then disable all the caches by
        # default, letting the tests turn them on explicitly. This is because
        # tests using PatchCaches will often want to check stats explicitly.
        cls._savedCacheState = {}
        for name in _CACHE_CONFIG_EN:
            if hasattr(config, name):
                cls._savedCacheState[name] = getattr(config, name)
            setattr(config, name, False)

    @classmethod
    def tearDown(cls):
        # Restore cache defaults, mirroring the hasattr guard used in setUp.
        for name in _CACHE_CONFIG_EN:
            if hasattr(config, name):
                delattr(config, name)
            if name in cls._savedCacheState:
                setattr(config, name, cls._savedCacheState[name])
    def __init__(self) -> None:
        self._stack = contextlib.ExitStack()

    def __enter__(self) -> Self:
        global_stats.reset()
        self._stack.__enter__()

        ctx = patch(
            "torch._inductor.runtime.autotune_cache.LocalAutotuneCache.backend_override_cls",
            MockBackend.with_name("autotune_local"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls",
            MockBackend.with_name("autotune_remote"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteBundledAutotuneCache.backend_override_cls",
            MockBackend.with_name("bundled_autotune"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls",
            MockBackend.with_name("fx_graph"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteAOTAutogradCache.backend_override_cls",
            MockBackend.with_name("aot_autograd"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls",
            MockBackend.with_name("dynamo_pgo"),
        )
        self._stack.enter_context(ctx)

        if config.is_fbcode():
            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls",
                MockBackend.with_name("autotune_remote"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteBundledAutotuneCache.backend_override_cls",
                MockBackend.with_name("bundled_autotune"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls",
                MockBackend.with_name("fx_graph"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls",
                MockBackend.with_name("triton"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteAOTAutogradCache.backend_override_cls",
                MockBackend.with_name("aot_autograd"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls",
                MockBackend.with_name("dynamo_pgo"),
            )
            self._stack.enter_context(ctx)

        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self._stack.__exit__(exc_type, exc_value, traceback)
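

# Example usage (a sketch, not part of the upstream file): a test that checks
# remote FX-graph cache traffic. `fn` and the exact expected counter values
# are hypothetical.
#
#     with PatchCaches():
#         with config.patch(fx_graph_remote_cache=True):
#             compiled = torch.compile(fn)
#             compiled(torch.randn(4))  # first run: one get miss, one put
#         assert global_stats.fx_graph == Stats(num_put=1, num_get_miss=1)
#         global_stats.report()  # dump counters and cache entries to stderr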