mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support caching if joint_custom_pre_pass/joint_custom_post_pass implement the proper interface (#157990)
Summary: Essentially, treat joint_custom_pre_pass/joint_custom_post_pass the same as post_grad_custom_post_pass/post_grad_custom_pre_pass. Test Plan: More unit tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/157990 Approved by: https://github.com/oulgen
This commit is contained in:
committed by
PyTorch MergeBot
parent
e172309880
commit
5bd7804be2
@ -2038,6 +2038,60 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
|
||||
result = torch.compile(f, backend=backend)(static_x)
|
||||
self.assertEqual(result, static_x * 3)
|
||||
|
||||
@config.patch({"fx_graph_cache": True})
|
||||
@config.patch({"fx_graph_remote_cache": False})
|
||||
def test_custom_pass_handling(self):
|
||||
"""
|
||||
Test that properly-registered custom hooks allow caching.
|
||||
"""
|
||||
|
||||
class TestCustomGraphPass(CustomGraphPass):
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
return None
|
||||
|
||||
def uuid(self) -> Optional[Union[bytes, str]]:
|
||||
return "uuid"
|
||||
|
||||
def fn(a, b):
|
||||
return torch.mm(a, b)
|
||||
|
||||
a = torch.rand(8, 32, device="cpu")
|
||||
b = torch.rand(32, 8, device="cpu")
|
||||
compiled_fn = torch.compile(fn)
|
||||
|
||||
# The cache should be bypassed if a custom hook doesn't use CustomGraphPass.
|
||||
with config.patch({"post_grad_custom_pre_pass": lambda x: x}):
|
||||
self.assertEqual(fn(a, b), compiled_fn(a, b))
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
|
||||
|
||||
# With proper usage, we expect normal caching.
|
||||
custom_pass = TestCustomGraphPass()
|
||||
with config.patch(
|
||||
{
|
||||
"post_grad_custom_pre_pass": custom_pass,
|
||||
"post_grad_custom_post_pass": custom_pass,
|
||||
"joint_custom_pre_pass": custom_pass,
|
||||
"joint_custom_post_pass": custom_pass,
|
||||
}
|
||||
):
|
||||
self.reset()
|
||||
counters.clear()
|
||||
|
||||
# Cache miss
|
||||
self.assertEqual(fn(a, b), compiled_fn(a, b))
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
|
||||
|
||||
self.reset()
|
||||
counters.clear()
|
||||
|
||||
# Cache hit
|
||||
self.assertEqual(fn(a, b), compiled_fn(a, b))
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
|
||||
|
||||
|
||||
class TestFxGraphCacheHashing(TestCase):
|
||||
def test_parameter_constants(self):
|
||||
|
@ -66,6 +66,17 @@ def change_cos_pass(graph):
|
||||
node.target = aten.sin.default
|
||||
|
||||
|
||||
class ChangeCosCustomPass(CustomGraphPass):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, g: torch.fx.graph.Graph):
|
||||
change_cos_pass(g)
|
||||
|
||||
def uuid(self) -> bytes:
|
||||
return get_hash_for_files((__file__,))
|
||||
|
||||
|
||||
class TestPostGradCustomPrePostPass(TestCustomPassBase):
|
||||
# mkldnn fusion's pattern_matcher
|
||||
# (torch/_inductor/fx_passes/mkldnn_fusion.py),
|
||||
@ -134,7 +145,7 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):
|
||||
return x1.relu()
|
||||
|
||||
def test_custom_joint_pass_pre(self):
|
||||
with config.patch(joint_custom_pre_pass=change_cos_pass):
|
||||
with config.patch(joint_custom_pre_pass=ChangeCosCustomPass()):
|
||||
|
||||
def g(x):
|
||||
return x.sin().sin().sin()
|
||||
@ -146,7 +157,7 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):
|
||||
torch.testing.assert_close(torch.compile(f)(x), g(x))
|
||||
|
||||
def test_custom_joint_pass_post(self):
|
||||
with config.patch(joint_custom_post_pass=change_cos_pass):
|
||||
with config.patch(joint_custom_post_pass=ChangeCosCustomPass()):
|
||||
|
||||
def g(x):
|
||||
return x.sin().sin().sin()
|
||||
|
@ -831,6 +831,12 @@ class FxGraphHashDetails:
|
||||
self.post_grad_custom_post_pass = self._get_custom_pass_detail(
|
||||
config.post_grad_custom_post_pass
|
||||
)
|
||||
self.joint_custom_pre_pass = self._get_custom_pass_detail(
|
||||
config.joint_custom_pre_pass
|
||||
)
|
||||
self.joint_custom_post_pass = self._get_custom_pass_detail(
|
||||
config.joint_custom_post_pass
|
||||
)
|
||||
self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
|
||||
config._pre_fusion_custom_pass
|
||||
)
|
||||
@ -1344,6 +1350,10 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
|
||||
if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
|
||||
raise BypassFxGraphCache("Unsupported post grad custom pass")
|
||||
# Same with the joint custom passes
|
||||
for p in (config.joint_custom_pre_pass, config.joint_custom_post_pass):
|
||||
if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
|
||||
raise BypassFxGraphCache("Unsupported joint custom pass")
|
||||
# We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
|
||||
# and ensure they are not passing us raw callables
|
||||
if config._pre_fusion_custom_pass is not None:
|
||||
|
@ -262,8 +262,8 @@ post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType
|
||||
post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
|
||||
|
||||
# Registers a custom joint graph pass.
|
||||
joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
|
||||
joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None
|
||||
joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
|
||||
joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
|
||||
|
||||
# Registers a custom pregrad pass. Note that the pre-grad IR is 1.
|
||||
# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
|
||||
@ -1766,6 +1766,8 @@ _cache_config_ignore_prefix: list[str] = [
|
||||
# see CustomGraphPass; these are handled specially
|
||||
"post_grad_custom_post_pass",
|
||||
"post_grad_custom_pre_pass",
|
||||
"joint_custom_pre_pass",
|
||||
"joint_custom_post_pass",
|
||||
"_fuse_ddp_communication_passes",
|
||||
"_pre_fusion_custom_pass",
|
||||
# tests assume that changes here don't invalidate cache
|
||||
|
Reference in New Issue
Block a user