Support caching if joint_custom_pre_pass/joint_custom_post_pass implement the proper interface (#157990)

Summary: Treat joint_custom_pre_pass/joint_custom_post_pass the same as post_grad_custom_pre_pass/post_grad_custom_post_pass: if the registered pass implements the CustomGraphPass interface (and thus provides a uuid()), include it in the FX graph cache key; otherwise bypass caching.

Test Plan: More unit tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157990
Approved by: https://github.com/oulgen
Author: Sam Larsen
Date: 2025-07-09 18:44:54 -07:00
Committed by: PyTorch MergeBot
Parent: e172309880
Commit: 5bd7804be2
4 changed files with 81 additions and 4 deletions
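
A minimal usage sketch (not part of the diff): after this change, a joint pass registered through the Inductor config should implement CustomGraphPass and return a stable uuid() so the FX graph cache is not bypassed. This assumes CustomGraphPass and get_hash_for_files are importable from torch._inductor.custom_graph_pass, as the tests below do; the MyJointPass class and fn function are hypothetical names used only for illustration.

import torch
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class MyJointPass(CustomGraphPass):  # hypothetical example pass
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        # Mutate the joint graph in place (a no-op in this sketch).
        return None

    def uuid(self) -> bytes:
        # Stable identifier so the FX graph cache can hash this pass
        # instead of bypassing caching.
        return get_hash_for_files((__file__,))


def fn(a, b):
    return torch.mm(a, b)


# With a CustomGraphPass-based joint pass, cache hits/misses behave normally;
# a raw callable here would force a cache bypass instead.
with config.patch(joint_custom_pre_pass=MyJointPass()):
    out = torch.compile(fn)(torch.rand(8, 32), torch.rand(32, 8))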


@@ -2038,6 +2038,60 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
         result = torch.compile(f, backend=backend)(static_x)
         self.assertEqual(result, static_x * 3)
 
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    def test_custom_pass_handling(self):
+        """
+        Test that properly-registered custom hooks allow caching.
+        """
+
+        class TestCustomGraphPass(CustomGraphPass):
+            def __call__(self, graph: torch.fx.graph.Graph) -> None:
+                return None
+
+            def uuid(self) -> Optional[Union[bytes, str]]:
+                return "uuid"
+
+        def fn(a, b):
+            return torch.mm(a, b)
+
+        a = torch.rand(8, 32, device="cpu")
+        b = torch.rand(32, 8, device="cpu")
+
+        compiled_fn = torch.compile(fn)
+
+        # The cache should be bypassed if a custom hook doesn't use CustomGraphPass.
+        with config.patch({"post_grad_custom_pre_pass": lambda x: x}):
+            self.assertEqual(fn(a, b), compiled_fn(a, b))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+        # With proper usage, we expect normal caching.
+        custom_pass = TestCustomGraphPass()
+        with config.patch(
+            {
+                "post_grad_custom_pre_pass": custom_pass,
+                "post_grad_custom_post_pass": custom_pass,
+                "joint_custom_pre_pass": custom_pass,
+                "joint_custom_post_pass": custom_pass,
+            }
+        ):
+            self.reset()
+            counters.clear()
+
+            # Cache miss
+            self.assertEqual(fn(a, b), compiled_fn(a, b))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            self.reset()
+            counters.clear()
+
+            # Cache hit
+            self.assertEqual(fn(a, b), compiled_fn(a, b))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
 
 class TestFxGraphCacheHashing(TestCase):
     def test_parameter_constants(self):


@@ -66,6 +66,17 @@ def change_cos_pass(graph):
             node.target = aten.sin.default
 
 
+class ChangeCosCustomPass(CustomGraphPass):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __call__(self, g: torch.fx.graph.Graph):
+        change_cos_pass(g)
+
+    def uuid(self) -> bytes:
+        return get_hash_for_files((__file__,))
+
+
 class TestPostGradCustomPrePostPass(TestCustomPassBase):
     # mkldnn fusion's pattern_matcher
     # (torch/_inductor/fx_passes/mkldnn_fusion.py),
@@ -134,7 +145,7 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):
            return x1.relu()
 
    def test_custom_joint_pass_pre(self):
-        with config.patch(joint_custom_pre_pass=change_cos_pass):
+        with config.patch(joint_custom_pre_pass=ChangeCosCustomPass()):
 
            def g(x):
                return x.sin().sin().sin()
@@ -146,7 +157,7 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):
            torch.testing.assert_close(torch.compile(f)(x), g(x))
 
    def test_custom_joint_pass_post(self):
-        with config.patch(joint_custom_post_pass=change_cos_pass):
+        with config.patch(joint_custom_post_pass=ChangeCosCustomPass()):
 
            def g(x):
                return x.sin().sin().sin()


@@ -831,6 +831,12 @@ class FxGraphHashDetails:
         self.post_grad_custom_post_pass = self._get_custom_pass_detail(
             config.post_grad_custom_post_pass
         )
+        self.joint_custom_pre_pass = self._get_custom_pass_detail(
+            config.joint_custom_pre_pass
+        )
+        self.joint_custom_post_pass = self._get_custom_pass_detail(
+            config.joint_custom_post_pass
+        )
         self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
             config._pre_fusion_custom_pass
         )
@@ -1344,6 +1350,10 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
             if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
                 raise BypassFxGraphCache("Unsupported post grad custom pass")
 
+        # Same with the joint custom passes
+        for p in (config.joint_custom_pre_pass, config.joint_custom_post_pass):
+            if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
+                raise BypassFxGraphCache("Unsupported joint custom pass")
+
         # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
         # and ensure they are not passing us raw callables
         if config._pre_fusion_custom_pass is not None:


@@ -262,8 +262,8 @@ post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
 post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
 
 # Registers a custom joint graph pass.
-joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
-joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None
+joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
+joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
 
 # Registers a custom pregrad pass. Note that the pre-grad IR is 1.
 # non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
@@ -1766,6 +1766,8 @@ _cache_config_ignore_prefix: list[str] = [
     # see CustomGraphPass; these are handled specially
     "post_grad_custom_post_pass",
     "post_grad_custom_pre_pass",
+    "joint_custom_pre_pass",
+    "joint_custom_post_pass",
     "_fuse_ddp_communication_passes",
     "_pre_fusion_custom_pass",
     # tests assume that changes here don't invalidate cache