Rename inductor cache (#156128)

Requested by Simon on a different PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156128
Approved by: https://github.com/xmfan
This commit is contained in:
Oguz Ulgen
2025-06-16 15:28:16 -07:00
committed by PyTorch MergeBot
parent 45382b284d
commit a2a75be0f8
48 changed files with 232 additions and 232 deletions

View File

@ -1015,29 +1015,6 @@ def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
return input_devices | out_devices
_registered_caches: list[Any] = []
def clear_on_fresh_inductor_cache(obj: Any) -> Any:
"""
Use this decorator to register any caches that should be cache_clear'd
with fresh_inductor_cache().
"""
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
raise AttributeError(f"{obj} does not have a cache_clear method")
_registered_caches.append(obj)
return obj
def clear_inductor_caches() -> None:
"""
Clear all registered caches.
"""
for obj in _registered_caches:
obj.cache_clear()
import gc
@ -1070,19 +1047,42 @@ def unload_xpu_triton_pyds() -> None:
gc.collect()
_registered_caches: list[Any] = []
def clear_on_fresh_cache(obj: Any) -> Any:
"""
Use this decorator to register any caches that should be cache_clear'd
with fresh_cache().
"""
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
raise AttributeError(f"{obj} does not have a cache_clear method")
_registered_caches.append(obj)
return obj
def clear_caches() -> None:
"""
Clear all registered caches.
"""
for obj in _registered_caches:
obj.cache_clear()
@contextlib.contextmanager
def fresh_inductor_cache(
def fresh_cache(
cache_entries: Optional[dict[str, Any]] = None,
dir: Optional[str] = None,
delete: bool = True,
) -> Iterator[None]:
"""
Contextmanager that provides a clean tmp cachedir for inductor.
Contextmanager that provides a clean tmp cachedir for pt2 caches.
Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
generated with this cache instance.
"""
clear_inductor_caches()
clear_caches()
inductor_cache_dir = tempfile.mkdtemp(dir=dir)
try:
@ -1123,7 +1123,13 @@ def fresh_inductor_cache(
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
raise
finally:
clear_inductor_caches()
clear_caches()
# Deprecated functions -- only keeping them for BC reasons
clear_on_fresh_inductor_cache = clear_on_fresh_cache
clear_inductor_caches = clear_caches
fresh_inductor_cache = fresh_cache
def argsort(seq: Sequence[Any]) -> list[int]: