inductor codecache: include private inductor configs in cache key (#153672)

Fixes https://github.com/pytorch/torchtitan/issues/1185

It looks like inductor's logic for including inductor configs in the cache key skips configs with a leading underscore by default. This came up in torchtitan: there is an asyncTP pipelining pass in inductor gated by a private config, and because we were not caching on that config, we could end up attempting to use asyncTP when we shouldn't have been.
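To make the failure mode concrete, here is a minimal sketch (it uses `_micro_pipeline_tp`, the private config exercised by the new test below; whether a stale hit actually occurs depends on prior cache state and cache configuration, so treat this as illustrative rather than a guaranteed repro):

import torch
from torch._inductor import config

def f(x):
    return x @ x

# First compile populates the FX graph cache under one private-config value.
with config.patch({"_micro_pipeline_tp": True}):
    torch.compile(f)(torch.randn(4, 4))

torch._dynamo.reset()

# Before this PR the cache key ignored "_micro_pipeline_tp", so this recompile
# could silently hit the entry produced above; after this PR it is a cache miss.
with config.patch({"_micro_pipeline_tp": False}):
    torch.compile(f)(torch.randn(4, 4))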

I'm not sure how worried we should be about the blast radius of this change. On the one hand:

(1) it technically fixes silent correctness issues in the cache around any other private inductor configs (it looks like there are a few).

On the other hand:

(2) there is some risk that we are now including some "harmless" configs in the key, which may increase false negatives (spurious cache misses). I do see that there is an explicit list of "configs we want to ignore for caching" (`_save_config_ignore`), so my hope is that all harmless configs are already captured there.
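For intuition, a simplified model (not the real implementation) of why flipping a private config could not change the cache key before this change:

# Pre-PR behavior of save_config_portable(), reduced to its filtering rule:
# any config whose name starts with "_" was unconditionally dropped, so two
# runs differing only in a private config produced identical portable dicts.
def old_portable_config(all_configs: dict) -> dict:
    ignored_prefixes = ["_"]  # private configs always filtered out
    return {
        k: v
        for k, v in all_configs.items()
        if not any(k.startswith(p) for p in ignored_prefixes)
    }

assert old_portable_config(
    {"_micro_pipeline_tp": True, "max_autotune": False}
) == old_portable_config({"_micro_pipeline_tp": False, "max_autotune": False})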

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153672
Approved by: https://github.com/oulgen
ghstack dependencies: #153766
Author: Brian Hirsh
Date: 2025-05-29 08:14:14 -07:00
Committed by: PyTorch MergeBot
Parent: 5b6fd277f9
Commit: 2c1cb38d95

5 changed files with 135 additions and 4 deletions


@@ -56,6 +56,7 @@ from torch.testing._internal.inductor_utils import (
     requires_gpu,
     requires_triton,
 )
+from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.testing._internal.triton_utils import requires_cuda
@@ -2120,6 +2121,90 @@ class TestFxGraphCacheHashing(TestCase):
             pickler.dumps(details3),
         )
 
+    def test_hash_private_config_changes(self):
+        """
+        Test that private config settings affect hashes.
+        """
+        with config.patch({"_micro_pipeline_tp": False}):
+            details1 = FxGraphHashDetails(None, [], {}, [])
+            details2 = FxGraphHashDetails(None, [], {}, [])
+
+        with config.patch({"_micro_pipeline_tp": True}):
+            details3 = FxGraphHashDetails(None, [], {}, [])
+
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
+
+        self.assertEqual(
+            pickler.dumps(details1),
+            pickler.dumps(details2),
+        )
+        self.assertNotEqual(
+            pickler.dumps(details1),
+            pickler.dumps(details3),
+        )
+
+    def test_non_serializable_custom_passes_causes_cache_miss(self):
+        class Mod(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 4))
+
+            def forward(self, x):
+                return x @ self.param
+
+        mod1 = Mod()
+        mod_compiled = torch.compile(mod1)
+        with torch.no_grad():
+            x = torch.rand(4, 4)
+            # miss
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            # hit
+            torch._dynamo.reset()
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+            torch._dynamo.reset()
+            counters.clear()
+            # hit
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+            with config.patch({"_fuse_ddp_communication_passes": ["new_pass_foo_bar"]}):
+                # miss (private config changed)
+                torch._dynamo.reset()
+                mod_compiled(x)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+            torch._dynamo.reset()
+            counters.clear()
+            (codecache_stream,), ctx = multiple_logs_to_string(
+                "torch._inductor.codecache", "codecache"
+            )
+            with ctx(), config.patch(
+                {"_fuse_ddp_communication_passes": [lambda *args: None]}
+            ):
+                # bypass (custom pass is not serializable)
+                mod_compiled(x)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+                counters.clear()
+
+            # assert that our bypass is explicit
+            codecache_logs = codecache_stream.getvalue().strip()
+            self.assertTrue(
+                "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'"
+                in codecache_logs
+            )
+
     def test_hash_custom_passes(self):
         """
         Test CustomGraphPass usage.


@@ -881,7 +881,7 @@ class FxGraphHashDetails:
         # Also hash on various system info (including the triton compiler version).
         self.torch_version = torch_key()
         self.system_info = CacheBase.get_system()
-        self.inductor_config = config.save_config_portable()
+        self.inductor_config = config.save_config_portable(ignore_private_configs=False)
         # Custom post grad passes should provide an ID to hash.
         self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
             config.post_grad_custom_pre_pass
@@ -889,6 +889,36 @@ class FxGraphHashDetails:
         self.post_grad_custom_post_pass = self._get_custom_pass_detail(
             config.post_grad_custom_post_pass
         )
+        self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
+            config._pre_fusion_custom_pass
+        )
+        self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe(
+            config._fuse_ddp_communication_passes
+        )
+
+    # This is mainly added to handle these two inductor configs, which are (unfortunately)
+    # sometimes cache safe:
+    # - _pre_fusion_custom_pass
+    # - _fuse_ddp_communication_passes
+    # Their types can be found in `torch/_inductor/config.py`, but:
+    # - if they are string names, we can cache them safely (one is by default)
+    # - if any of them are set to custom callables, we will need to cache miss
+    # Future work is for someone to find any places where these functions are used
+    # and force them to be of type CustomGraphPass, so we can guarantee serialization.
+    def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
+        if not custom_pass:
+            return None
+        if isinstance(custom_pass, list):
+            return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass]
+        if isinstance(custom_pass, str):
+            return custom_pass
+        if isinstance(custom_pass, CustomGraphPass):
+            return custom_pass.uuid()
+        if callable(custom_pass):
+            # Returning None is safe here because we raise an explicit bypass error
+            # later if we detect these passes are set to callables
+            return None
+        raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
 
     def _get_custom_pass_detail(
         self, custom_pass: CustomGraphPassType
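To make the branches of `_get_custom_pass_detail_unsafe` concrete, an illustrative sketch (`MyPass` is a hypothetical `CustomGraphPass` subclass; the `FxGraphHashDetails(None, [], {}, [])` construction mirrors the new test above):

from typing import Any
from torch._inductor.codecache import FxGraphHashDetails
from torch._inductor.custom_graph_pass import CustomGraphPass

class MyPass(CustomGraphPass):  # hypothetical pass with a stable uuid
    def __call__(self, graph: Any) -> None:
        pass  # illustrative no-op

    def uuid(self) -> bytes:
        return b"my-pass-v1"

details = FxGraphHashDetails(None, [], {}, [])
assert details._get_custom_pass_detail_unsafe(None) is None              # unset
assert details._get_custom_pass_detail_unsafe("foo_pass") == "foo_pass"  # strings hash as-is
assert details._get_custom_pass_detail_unsafe(MyPass()) == b"my-pass-v1" # uuid() is hashed
# A raw callable yields None here; the explicit bypass check below rejects it instead.
assert details._get_custom_pass_detail_unsafe(lambda g: None) is None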
@@ -1366,6 +1396,14 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
             if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
                 raise BypassFxGraphCache("Unsupported post grad custom pass")
+        # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
+        # and ensure they are not passing us raw callables
+        if config._pre_fusion_custom_pass is not None:
+            if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass):
+                raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass")
+        for p in config._fuse_ddp_communication_passes:
+            if callable(p) and not isinstance(p, CustomGraphPass):
+                raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass")
 
         # Freezing can embed constants that wouldn't be static across runs.
         if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
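As the comments note, the cache-friendly alternative to a raw callable is a `CustomGraphPass`, whose `uuid()` feeds the hash instead of forcing a bypass. A minimal sketch (the class name and pass body are illustrative; `get_hash_for_files` is the helper PyTorch ships for deriving a uuid from source files):

import torch
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files

class MyDDPCommFusionPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.Graph) -> None:
        # illustrative no-op; a real pass would rewrite communication nodes
        pass

    def uuid(self) -> bytes:
        # key the cache on this file's contents, so editing the pass
        # invalidates previously cached graphs
        return get_hash_for_files((__file__,))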


@@ -1635,6 +1635,8 @@ _save_config_ignore: list[str] = [
     "aot_inductor.dump_aoti_minifier",
     "post_grad_custom_pre_pass",
     "post_grad_custom_post_pass",
+    "_fuse_ddp_communication_passes",
+    "_pre_fusion_custom_pass",
 ]
 
 _cache_config_ignore_prefix: list[str] = [

@@ -1648,6 +1650,8 @@ _cache_config_ignore_prefix: list[str] = [
     # see CustomGraphPass; these are handled specially
     "post_grad_custom_post_pass",
     "post_grad_custom_pre_pass",
+    "_fuse_ddp_communication_passes",
+    "_pre_fusion_custom_pass",
     # tests assume that changes here don't invalidate cache
     "always_complex_memory_overlap_TESTING_ONLY",
 ]


@@ -508,9 +508,13 @@ class ConfigModule(ModuleType):
             protocol=2,
         )
 
-    def save_config_portable(self) -> dict[str, Any]:
+    def save_config_portable(
+        self, *, ignore_private_configs: bool = True
+    ) -> dict[str, Any]:
         """Convert config to portable format"""
-        prefixes = ["_"]
+        prefixes = []
+        if ignore_private_configs:
+            prefixes.append("_")
         prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
         return self._get_dict(ignored_prefixes=prefixes)
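Call sites keep the old semantics by default; only the cache-key path opts in. A quick sketch of the resulting behavior (assuming the stock inductor config set, where `_micro_pipeline_tp` is a private config and `_pre_fusion_custom_pass` sits in `_cache_config_ignore_prefix` per the diff above):

from torch._inductor import config

portable = config.save_config_portable()  # default: "_" configs still skipped
full = config.save_config_portable(ignore_private_configs=False)

assert "_micro_pipeline_tp" not in portable
assert "_micro_pipeline_tp" in full
# Still excluded even from `full`, via _cache_config_ignore_prefix:
assert "_pre_fusion_custom_pass" not in full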


@@ -24,7 +24,7 @@ Note that the import should happen before the call to install_config_module(), o
 assert TYPE_CHECKING, "Do not use at runtime"
 
 def save_config() -> bytes: ...
-def save_config_portable() -> dict[str, Any]: ...
+def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
 def codegen_config() -> str: ...
 def get_hash() -> bytes: ...
 def to_dict() -> dict[str, Any]: ...