inductor codecache: include private inductor configs in cache key (#153672)

Fixes https://github.com/pytorch/torchtitan/issues/1185

It looks like inductor's logic to include inductor configs in the cache key skips configs with a leading underscore by default. This came up in torchtitan - there's an asyncTP pipelining pass in inductor gated by a private config, and by not caching on the config we were attempting to use asyncTP when we shouldn't be.

I'm not sure how worried we should be about the blast radius of this change. On the one hand:

(1) it technically fixes any silent correctness issues in the cache around any other private inductor configs (it looks like there are a few)

(2) there is some risk that there are some "harmless" configs that we are now including in the key, which may increase false negatives. I do see that there is an explicit list for "configs we want to ignore for caching" (`_save_config_ignore`), so my hope is that all harmless configs are already encapsulated there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153672
Approved by: https://github.com/oulgen
ghstack dependencies: #153766
This commit is contained in:
Brian Hirsh
2025-05-29 08:14:14 -07:00
committed by PyTorch MergeBot
parent 5b6fd277f9
commit 2c1cb38d95
5 changed files with 135 additions and 4 deletions

View File

@ -56,6 +56,7 @@ from torch.testing._internal.inductor_utils import (
requires_gpu, requires_gpu,
requires_triton, requires_triton,
) )
from torch.testing._internal.logging_utils import multiple_logs_to_string
from torch.testing._internal.triton_utils import requires_cuda from torch.testing._internal.triton_utils import requires_cuda
@ -2120,6 +2121,90 @@ class TestFxGraphCacheHashing(TestCase):
pickler.dumps(details3), pickler.dumps(details3),
) )
def test_hash_private_config_changes(self):
    """
    Test that private (underscore-prefixed) config settings affect hashes.
    """
    # Two hash-details objects built under identical private config must
    # serialize identically; flipping the private flag must change the bytes.
    with config.patch({"_micro_pipeline_tp": False}):
        details1 = FxGraphHashDetails(None, [], {}, [])
        details2 = FxGraphHashDetails(None, [], {}, [])
    with config.patch({"_micro_pipeline_tp": True}):
        details3 = FxGraphHashDetails(None, [], {}, [])

    pickler = FxGraphCachePickler(torch.fx.GraphModule({}, torch.fx.Graph()))

    self.assertEqual(pickler.dumps(details1), pickler.dumps(details2))
    self.assertNotEqual(pickler.dumps(details1), pickler.dumps(details3))
def test_non_serializable_custom_passes_causes_cache_miss(self):
    """
    Private custom-pass configs participate in the cache key: changing them
    causes a miss, and setting them to raw (non-serializable) callables
    causes an explicit, logged cache bypass.
    """

    class Mod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 4))

        def forward(self, x):
            return x @ self.param

    def assert_cache_counters(bypass, miss, hit):
        # Check all three fxgraph-cache counters in one shot.
        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], bypass)
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], miss)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], hit)

    mod1 = Mod()
    mod_compiled = torch.compile(mod1)
    with torch.no_grad():
        x = torch.rand(4, 4)

        # Cold compile: miss.
        mod_compiled(x)
        assert_cache_counters(bypass=0, miss=1, hit=0)

        # Recompile after a dynamo reset: hit.
        torch._dynamo.reset()
        mod_compiled(x)
        assert_cache_counters(bypass=0, miss=1, hit=1)

        torch._dynamo.reset()
        counters.clear()
        # Still a hit with fresh counters.
        mod_compiled(x)
        assert_cache_counters(bypass=0, miss=0, hit=1)

        with config.patch({"_fuse_ddp_communication_passes": ["new_pass_foo_bar"]}):
            # Changing the private config must invalidate the cache: miss.
            torch._dynamo.reset()
            mod_compiled(x)
            assert_cache_counters(bypass=0, miss=1, hit=1)

            torch._dynamo.reset()
            counters.clear()
            (codecache_stream,), ctx = multiple_logs_to_string(
                "torch._inductor.codecache", "codecache"
            )
            with ctx(), config.patch(
                {"_fuse_ddp_communication_passes": [lambda *args: None]}
            ):
                # A raw callable pass is not serializable: explicit bypass.
                mod_compiled(x)
                assert_cache_counters(bypass=1, miss=0, hit=0)
            counters.clear()

            # The bypass must be logged explicitly, not swallowed silently.
            codecache_logs = codecache_stream.getvalue().strip()
            self.assertTrue(
                "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'"
                in codecache_logs
            )
def test_hash_custom_passes(self): def test_hash_custom_passes(self):
""" """
Test CustomGraphPass usage. Test CustomGraphPass usage.

View File

@ -881,7 +881,7 @@ class FxGraphHashDetails:
# Also hash on various system info (including the triton compiler version). # Also hash on various system info (including the triton compiler version).
self.torch_version = torch_key() self.torch_version = torch_key()
self.system_info = CacheBase.get_system() self.system_info = CacheBase.get_system()
self.inductor_config = config.save_config_portable() self.inductor_config = config.save_config_portable(ignore_private_configs=False)
# Custom post grad passes should provide an ID to hash. # Custom post grad passes should provide an ID to hash.
self.post_grad_custom_pre_pass = self._get_custom_pass_detail( self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
config.post_grad_custom_pre_pass config.post_grad_custom_pre_pass
@ -889,6 +889,36 @@ class FxGraphHashDetails:
self.post_grad_custom_post_pass = self._get_custom_pass_detail( self.post_grad_custom_post_pass = self._get_custom_pass_detail(
config.post_grad_custom_post_pass config.post_grad_custom_post_pass
) )
self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
config._pre_fusion_custom_pass
)
self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe(
config._fuse_ddp_communication_passes
)
# This is mainly added to handle these two inductor configs, which are (unfortunately)
# sometimes cache safe:
# - _pre_fusion_custom_pass
# - _fuse_ddp_communication_passes
# Their types can be found in `torch/_inductor/config.py`, but:
# - if they are string names, we can cache them safely (one is by default)
# - if any of them are set to custom callables, we will need to cache miss
# Future work is for someone to find any places where these functions are used
# and force them to be of type CustomGraphPass, so we can guarantee serialization.
def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
if not custom_pass:
return None
if isinstance(custom_pass, list):
return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass]
if isinstance(custom_pass, str):
return custom_pass
if isinstance(custom_pass, CustomGraphPass):
return custom_pass.uuid()
if callable(custom_pass):
# Returning None is safe here because we raise an explicit bypass error
# later if we detect these passes are set to callables
return None
raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
def _get_custom_pass_detail( def _get_custom_pass_detail(
self, custom_pass: CustomGraphPassType self, custom_pass: CustomGraphPassType
@ -1366,6 +1396,14 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass): for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
if p and (not isinstance(p, CustomGraphPass) or not p.uuid()): if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
raise BypassFxGraphCache("Unsupported post grad custom pass") raise BypassFxGraphCache("Unsupported post grad custom pass")
# We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
# and ensure they are not passing us raw callables
if config._pre_fusion_custom_pass is not None:
if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass):
raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass")
for p in config._fuse_ddp_communication_passes:
if callable(p) and not isinstance(p, CustomGraphPass):
raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass")
# Freezing can embed constants that wouldn't be static across runs. # Freezing can embed constants that wouldn't be static across runs.
if has_frozen_params(gm) and not torch._utils_internal.justknobs_check( if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(

View File

@ -1635,6 +1635,8 @@ _save_config_ignore: list[str] = [
"aot_inductor.dump_aoti_minifier", "aot_inductor.dump_aoti_minifier",
"post_grad_custom_pre_pass", "post_grad_custom_pre_pass",
"post_grad_custom_post_pass", "post_grad_custom_post_pass",
"_fuse_ddp_communication_passes",
"_pre_fusion_custom_pass",
] ]
_cache_config_ignore_prefix: list[str] = [ _cache_config_ignore_prefix: list[str] = [
@ -1648,6 +1650,8 @@ _cache_config_ignore_prefix: list[str] = [
# see CustomGraphPass; these are handled specially # see CustomGraphPass; these are handled specially
"post_grad_custom_post_pass", "post_grad_custom_post_pass",
"post_grad_custom_pre_pass", "post_grad_custom_pre_pass",
"_fuse_ddp_communication_passes",
"_pre_fusion_custom_pass",
# tests assume that changes here don't invalidate cache # tests assume that changes here don't invalidate cache
"always_complex_memory_overlap_TESTING_ONLY", "always_complex_memory_overlap_TESTING_ONLY",
] ]

View File

@ -508,9 +508,13 @@ class ConfigModule(ModuleType):
protocol=2, protocol=2,
) )
def save_config_portable(
    self, *, ignore_private_configs: bool = True
) -> dict[str, Any]:
    """Convert config to portable format.

    By default, private configs (names starting with an underscore) are
    excluded; pass ``ignore_private_configs=False`` to include them, e.g.
    when computing a cache key that must reflect private settings.
    """
    skipped_prefixes = ["_"] if ignore_private_configs else []
    skipped_prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
    return self._get_dict(ignored_prefixes=skipped_prefixes)

View File

@ -24,7 +24,7 @@ Note that the import should happen before the call to install_config_module(), o
assert TYPE_CHECKING, "Do not use at runtime" assert TYPE_CHECKING, "Do not use at runtime"
# Type-checking-only stubs declaring the public ConfigModule surface;
# bodies are intentionally empty (never executed at runtime).
def save_config() -> bytes: ...
def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
def codegen_config() -> str: ...
def get_hash() -> bytes: ...
def to_dict() -> dict[str, Any]: ...