Add torch compile force disable caches alias (#158072)

Bunch of people keep thinking current alias only disables inductor cache because it has the name inductor in it. lets globalize the name

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158072
Approved by: https://github.com/ezyang
This commit is contained in:
Oguz Ulgen
2025-08-02 10:22:34 -07:00
committed by PyTorch MergeBot
parent d2792f51b2
commit a29ed5e1ac
6 changed files with 21 additions and 14 deletions

View File

@ -717,5 +717,5 @@ backtrace is slow and very spammy so it is not included by default with extended
In order to measure the cold start compilation time or debug a cache corruption,
it is possible pass `TORCHINDUCTOR_FORCE_DISABLE_CACHES=1` or set
`torch._inductor.config.force_disable_caches = True` which will override any
`torch.compiler.config.force_disable_caches = True` which will override any
other caching config option and disable all compile time caching.

View File

@ -95,15 +95,14 @@ class TestScheduler(TestCase):
{
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
"force_disable_caches": True,
},
{
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON,ATEN",
"force_disable_caches": True,
},
],
)
@torch._inductor.config.patch({"force_disable_caches": True})
@skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune")
def test_flop_counter_op(self, device, dtype, options):
if device == "cpu":

View File

@ -521,9 +521,9 @@ def process_automatic_dynamic(
def get_cache_key() -> Optional[str]:
# TODO: info versions of these logs that log only once
if torch._inductor.config.force_disable_caches:
if torch.compiler.config.force_disable_caches:
warn_once(
"dynamo_pgo force disabled by torch._inductor.config.force_disable_caches"
"dynamo_pgo force disabled by torch.compiler.config.force_disable_caches"
)
return None
@ -566,7 +566,7 @@ def code_state_path(cache_key: str) -> Optional[str]:
def should_use_remote_dynamo_pgo_cache() -> bool:
if torch._inductor.config.force_disable_caches:
if torch.compiler.config.force_disable_caches:
return False
if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None:

View File

@ -95,7 +95,7 @@ class FXGraphCacheMiss(BypassAOTAutogradCache):
def should_use_remote_autograd_cache():
if torch._inductor.config.force_disable_caches:
if torch.compiler.config.force_disable_caches:
return False
if config.enable_remote_autograd_cache is not None:
return config.enable_remote_autograd_cache
@ -116,7 +116,7 @@ def should_use_remote_autograd_cache():
def should_use_local_autograd_cache():
if torch._inductor.config.force_disable_caches:
if torch.compiler.config.force_disable_caches:
return False
return config.enable_autograd_cache

View File

@ -138,12 +138,8 @@ autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()
# None: Not set -- Off for OSS, JustKnobs based for internal
bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default()
# Force disabled all inductor level caching -- This will override any other caching flag
force_disable_caches: bool = Config(
justknob="pytorch/remote_cache:force_disable_caches",
env_name_force="TORCHINDUCTOR_FORCE_DISABLE_CACHES",
default=False,
)
# See torch.compiler.config.force_disable_caches
force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches")
# Unsafe way to skip dynamic shape guards to get faster cache load
unsafe_skip_cache_dynamic_shape_guards: bool = False

View File

@ -66,6 +66,18 @@ Tag to be included in the cache key generation for all torch compile caching.
A common use case for such a tag is to break caches.
"""
force_disable_caches: bool = Config(
justknob="pytorch/remote_cache:force_disable_caches",
env_name_force=[
"TORCHINDUCTOR_FORCE_DISABLE_CACHES",
"TORCH_COMPILE_FORCE_DISABLE_CACHES",
],
default=False,
)
"""
Force disables all caching -- This will take precedence over and override any other caching flag
"""
dynamic_sources: str = Config(
env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default=""
)