inductor codecache: include private inductor configs in cache key (#153672)
Fixes https://github.com/pytorch/torchtitan/issues/1185

Inductor's logic for including inductor configs in the cache key skips configs with a leading underscore by default. This came up in torchtitan: an asyncTP pipelining pass in inductor is gated by a private config, and because the cache key did not include that config, we could end up attempting to use asyncTP when we shouldn't.

I'm not sure how worried we should be about the blast radius of this change:
(1) It technically fixes silent correctness issues in the cache for any other private inductor configs (it looks like there are a few).
(2) There is some risk that we are now including "harmless" configs in the key, which may increase false negatives (spurious cache misses). There is an explicit list of configs to ignore for caching (`_save_config_ignore`), so my hope is that all harmless configs are already captured there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153672
Approved by: https://github.com/oulgen
ghstack dependencies: #153766
Committed by: PyTorch MergeBot
Parent: 5b6fd277f9
Commit: 2c1cb38d95
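For context, a minimal sketch of what the fix changes in practice. It closely mirrors the `test_hash_private_config_changes` test added in the diff below; the `FxGraphHashDetails(None, [], {}, [])` constructor arguments are taken from that test and may differ across PyTorch versions.

```python
# Sketch: with this change, flipping a private inductor config ("_micro_pipeline_tp",
# the private config exercised by the new test) produces a different FX graph cache key.
import torch
from torch._inductor import config
from torch._inductor.codecache import FxGraphCachePickler, FxGraphHashDetails

gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)

with config.patch({"_micro_pipeline_tp": False}):
    key_off = pickler.dumps(FxGraphHashDetails(None, [], {}, []))
with config.patch({"_micro_pipeline_tp": True}):
    key_on = pickler.dumps(FxGraphHashDetails(None, [], {}, []))

# Before this PR the two keys were identical (private configs were skipped);
# now they differ, so a config change cannot silently reuse a stale cache entry.
assert key_off != key_on
```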
@@ -56,6 +56,7 @@ from torch.testing._internal.inductor_utils import (
     requires_gpu,
     requires_triton,
 )
+from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.testing._internal.triton_utils import requires_cuda
 
 
@@ -2120,6 +2121,90 @@ class TestFxGraphCacheHashing(TestCase):
             pickler.dumps(details3),
         )
 
+    def test_hash_private_config_changes(self):
+        """
+        Test that private config settings affect hashes.
+        """
+        with config.patch({"_micro_pipeline_tp": False}):
+            details1 = FxGraphHashDetails(None, [], {}, [])
+            details2 = FxGraphHashDetails(None, [], {}, [])
+
+        with config.patch({"_micro_pipeline_tp": True}):
+            details3 = FxGraphHashDetails(None, [], {}, [])
+
+        gm = torch.fx.GraphModule({}, torch.fx.Graph())
+        pickler = FxGraphCachePickler(gm)
+
+        self.assertEqual(
+            pickler.dumps(details1),
+            pickler.dumps(details2),
+        )
+        self.assertNotEqual(
+            pickler.dumps(details1),
+            pickler.dumps(details3),
+        )
+
+    def test_non_serializable_custom_passes_causes_cache_miss(self):
+        class Mod(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 4))
+
+            def forward(self, x):
+                return x @ self.param
+
+        mod1 = Mod()
+        mod_compiled = torch.compile(mod1)
+        with torch.no_grad():
+            x = torch.rand(4, 4)
+            # miss
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            # hit
+            torch._dynamo.reset()
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            torch._dynamo.reset()
+            counters.clear()
+
+            # hit
+            mod_compiled(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+            with config.patch({"_fuse_ddp_communication_passes": ["new_pass_foo_bar"]}):
+                # miss (private config changed)
+                torch._dynamo.reset()
+                mod_compiled(x)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+                torch._dynamo.reset()
+                counters.clear()
+
+            (codecache_stream,), ctx = multiple_logs_to_string(
+                "torch._inductor.codecache", "codecache"
+            )
+            with ctx(), config.patch(
+                {"_fuse_ddp_communication_passes": [lambda *args: None]}
+            ):
+                # bypass (custom pass is not serializable)
+                mod_compiled(x)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+                counters.clear()
+                # assert that our bypass is explicit
+                codecache_logs = codecache_stream.getvalue().strip()
+                self.assertTrue(
+                    "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'"
+                    in codecache_logs
+                )
+
     def test_hash_custom_passes(self):
         """
         Test CustomGraphPass usage.
@@ -881,7 +881,7 @@ class FxGraphHashDetails:
         # Also hash on various system info (including the triton compiler version).
         self.torch_version = torch_key()
         self.system_info = CacheBase.get_system()
-        self.inductor_config = config.save_config_portable()
+        self.inductor_config = config.save_config_portable(ignore_private_configs=False)
         # Custom post grad passes should provide an ID to hash.
         self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
             config.post_grad_custom_pre_pass
@@ -889,6 +889,36 @@ class FxGraphHashDetails:
         self.post_grad_custom_post_pass = self._get_custom_pass_detail(
             config.post_grad_custom_post_pass
         )
+        self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
+            config._pre_fusion_custom_pass
+        )
+        self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe(
+            config._fuse_ddp_communication_passes
+        )
+
+    # This is mainly added to handle these two inductor configs, which are (unfortunately)
+    # sometimes cache safe:
+    # - _pre_fusion_custom_pass
+    # - _fuse_ddp_communication_passes
+    # Their types can be found in `torch/_inductor/config.py`, but:
+    # - if they are string names, we can cache them safely (one is by default)
+    # - if any of them are set to custom callables, we will need to cache miss
+    # Future work is for someone to find any places where these functions are used
+    # and force them to be of type CustomGraphPass, so we can guarantee serialization.
+    def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
+        if not custom_pass:
+            return None
+        if isinstance(custom_pass, list):
+            return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass]
+        if isinstance(custom_pass, str):
+            return custom_pass
+        if isinstance(custom_pass, CustomGraphPass):
+            return custom_pass.uuid()
+        if callable(custom_pass):
+            # Returning None is safe here because we raise an explicit bypass error
+            # later if we detect these passes are set to callables
+            return None
+        raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
 
     def _get_custom_pass_detail(
         self, custom_pass: CustomGraphPassType
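The comment above suggests migrating these private passes to `CustomGraphPass` so their identity can be serialized. A minimal sketch of that direction, assuming the `CustomGraphPass` interface referenced in this diff (a `__call__` over the FX graph plus a `uuid()` used for hashing) from `torch._inductor.custom_graph_pass`; the pass name and uuid value are hypothetical:

```python
# Sketch only: a pass object whose uuid() gives the cache a stable identity.
from typing import Any, Optional

import torch
from torch._inductor.custom_graph_pass import CustomGraphPass


class MyDDPCommFusionPass(CustomGraphPass):  # hypothetical pass
    def __call__(self, graph: torch.fx.Graph) -> None:
        # (hypothetical) rewrite the post-grad FX graph in place here
        pass

    def uuid(self) -> Optional[Any]:
        # Any stable value works; bump it when the pass logic changes so that
        # previously cached graphs are invalidated.
        return "my_ddp_comm_fusion_pass_v1"


# _get_custom_pass_detail_unsafe hashes an instance like this via its uuid(),
# so the FX graph cache stays usable instead of being bypassed.
```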
@@ -1366,6 +1396,14 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
        for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
            if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
                raise BypassFxGraphCache("Unsupported post grad custom pass")
+        # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
+        # and ensure they are not passing us raw callables
+        if config._pre_fusion_custom_pass is not None:
+            if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass):
+                raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass")
+        for p in config._fuse_ddp_communication_passes:
+            if callable(p) and not isinstance(p, CustomGraphPass):
+                raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass")
 
         # Freezing can embed constants that wouldn't be static across runs.
         if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
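For the bypass path added above, a hedged sketch of the caller-side behavior; it mirrors the new test, and `my_raw_pass` is a hypothetical stand-in for a user-supplied callable:

```python
# Sketch: a raw callable in _fuse_ddp_communication_passes now triggers an explicit
# FX graph cache bypass (counted in counters["inductor"]["fxgraph_cache_bypass"])
# instead of being silently ignored by the cache key.
import torch
from torch._inductor import config


def my_raw_pass(graph):  # hypothetical: not a CustomGraphPass, so not serializable
    return None


compiled = torch.compile(torch.nn.Linear(4, 4))
with torch.no_grad(), config.patch({"_fuse_ddp_communication_passes": [my_raw_pass]}):
    compiled(torch.rand(2, 4))  # compiles fine, but caching is bypassed for this graph
```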
@@ -1635,6 +1635,8 @@ _save_config_ignore: list[str] = [
     "aot_inductor.dump_aoti_minifier",
     "post_grad_custom_pre_pass",
     "post_grad_custom_post_pass",
+    "_fuse_ddp_communication_passes",
+    "_pre_fusion_custom_pass",
 ]
 
 _cache_config_ignore_prefix: list[str] = [
@@ -1648,6 +1650,8 @@ _cache_config_ignore_prefix: list[str] = [
     # see CustomGraphPass; these are handled specially
     "post_grad_custom_post_pass",
     "post_grad_custom_pre_pass",
+    "_fuse_ddp_communication_passes",
+    "_pre_fusion_custom_pass",
     # tests assume that changes here don't invalidate cache
     "always_complex_memory_overlap_TESTING_ONLY",
 ]
@@ -508,9 +508,13 @@ class ConfigModule(ModuleType):
             protocol=2,
         )
 
-    def save_config_portable(self) -> dict[str, Any]:
+    def save_config_portable(
+        self, *, ignore_private_configs: bool = True
+    ) -> dict[str, Any]:
         """Convert config to portable format"""
-        prefixes = ["_"]
+        prefixes = []
+        if ignore_private_configs:
+            prefixes.append("_")
         prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
         return self._get_dict(ignored_prefixes=prefixes)
 
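A small sketch of the new keyword's effect; illustrative only, and `_micro_pipeline_tp` is simply a private inductor config used elsewhere in this PR's test:

```python
# Sketch: save_config_portable() still skips "_"-prefixed configs by default, while
# the cache-key path passes ignore_private_configs=False to include them. Configs
# whose names are in _cache_config_ignore_prefix (e.g. the two custom-pass configs
# added above) stay excluded either way and are hashed separately.
from torch._inductor import config

public_only = config.save_config_portable()
with_private = config.save_config_portable(ignore_private_configs=False)

assert "_micro_pipeline_tp" not in public_only   # private configs skipped by default
assert "_micro_pipeline_tp" in with_private      # ...but included when requested
```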
@@ -24,7 +24,7 @@ Note that the import should happen before the call to install_config_module(), o
 assert TYPE_CHECKING, "Do not use at runtime"
 
 def save_config() -> bytes: ...
-def save_config_portable() -> dict[str, Any]: ...
+def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
 def codegen_config() -> str: ...
 def get_hash() -> bytes: ...
 def to_dict() -> dict[str, Any]: ...