mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rename inductor cache (#156128)
Requested by Simon on a different PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/156128 Approved by: https://github.com/xmfan
This commit is contained in:
committed by
PyTorch MergeBot
parent
45382b284d
commit
a2a75be0f8
@ -1015,29 +1015,6 @@ def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
|
||||
return input_devices | out_devices
|
||||
|
||||
|
||||
_registered_caches: list[Any] = []
|
||||
|
||||
|
||||
def clear_on_fresh_inductor_cache(obj: Any) -> Any:
|
||||
"""
|
||||
Use this decorator to register any caches that should be cache_clear'd
|
||||
with fresh_inductor_cache().
|
||||
"""
|
||||
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
|
||||
raise AttributeError(f"{obj} does not have a cache_clear method")
|
||||
|
||||
_registered_caches.append(obj)
|
||||
return obj
|
||||
|
||||
|
||||
def clear_inductor_caches() -> None:
|
||||
"""
|
||||
Clear all registered caches.
|
||||
"""
|
||||
for obj in _registered_caches:
|
||||
obj.cache_clear()
|
||||
|
||||
|
||||
import gc
|
||||
|
||||
|
||||
@ -1070,19 +1047,42 @@ def unload_xpu_triton_pyds() -> None:
|
||||
gc.collect()
|
||||
|
||||
|
||||
_registered_caches: list[Any] = []
|
||||
|
||||
|
||||
def clear_on_fresh_cache(obj: Any) -> Any:
|
||||
"""
|
||||
Use this decorator to register any caches that should be cache_clear'd
|
||||
with fresh_cache().
|
||||
"""
|
||||
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
|
||||
raise AttributeError(f"{obj} does not have a cache_clear method")
|
||||
|
||||
_registered_caches.append(obj)
|
||||
return obj
|
||||
|
||||
|
||||
def clear_caches() -> None:
|
||||
"""
|
||||
Clear all registered caches.
|
||||
"""
|
||||
for obj in _registered_caches:
|
||||
obj.cache_clear()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fresh_inductor_cache(
|
||||
def fresh_cache(
|
||||
cache_entries: Optional[dict[str, Any]] = None,
|
||||
dir: Optional[str] = None,
|
||||
delete: bool = True,
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
Contextmanager that provides a clean tmp cachedir for inductor.
|
||||
Contextmanager that provides a clean tmp cachedir for pt2 caches.
|
||||
|
||||
Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
|
||||
generated with this cache instance.
|
||||
"""
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
inductor_cache_dir = tempfile.mkdtemp(dir=dir)
|
||||
try:
|
||||
@ -1123,7 +1123,13 @@ def fresh_inductor_cache(
|
||||
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
|
||||
raise
|
||||
finally:
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
|
||||
# Deprecated functions -- only keeping them for BC reasons
|
||||
clear_on_fresh_inductor_cache = clear_on_fresh_cache
|
||||
clear_inductor_caches = clear_caches
|
||||
fresh_inductor_cache = fresh_cache
|
||||
|
||||
|
||||
def argsort(seq: Sequence[Any]) -> list[int]:
|
||||
|
||||
Reference in New Issue
Block a user