From 5da573c42c332bc68d4b7946c69f690a876d951a Mon Sep 17 00:00:00 2001
From: Pian Pawakapan
Date: Fri, 5 Sep 2025 04:58:15 +0000
Subject: [PATCH] [PGO] handle PGO profile merges (#162097)

Avoid merging entries from an extra PGO key when the same source has a
different rank. Unlikely to happen (requires a code hash match while the
source variable's type changes), but guard against it to be safe.

Differential Revision: D81299840

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162097
Approved by: https://github.com/bobrenjc93
---
 test/dynamo/test_pgo.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
 torch/_dynamo/pgo.py    | 14 +++++++++++++-
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py
index de7679ed1863..ce2fda138729 100644
--- a/test/dynamo/test_pgo.py
+++ b/test/dynamo/test_pgo.py
@@ -453,6 +453,50 @@ def run(cnt):
         f(t(4, 2), t(2, 4))
         self.assertEqual(cnts.frame_count, 1)
 
+    def test_profile_merges(self):
+        from torch._dynamo.pgo import auto_dynamic, merge_pgo_entry
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def f(ints, t_scalar, tensors):
+            # arbitrary compute
+            return ints[0] + ints[1], t_scalar + 1, [t + 1 for t in tensors]
+
+        # single static run
+        f(
+            [0, 2],
+            torch.tensor(0),
+            [
+                torch.randn(2),
+                torch.randn(2, 2),
+                torch.randn(4, 4),
+            ],
+        )
+        # collect profiles
+        profile = next(
+            iter(torch._dynamo.pgo.get_code_state().values())
+        ).automatic_dynamic
+        i0, i1 = profile["L['ints'][0]"], profile["L['ints'][1]"]
+        ts = profile["L['t_scalar']"]
+        t0, t1, t2 = (
+            profile["L['tensors'][0]"],
+            profile["L['tensors'][1]"],
+            profile["L['tensors'][2]"],
+        )
+        # merging the same scalar, or a tensor into a scalar -> no-op
+        merge_pgo_entry(i0, i0)
+        merge_pgo_entry(ts, i0)
+        merge_pgo_entry(t0, i0)
+        self.assertEqual(i0.scalar, 0)
+        # merging different scalars -> dynamic
+        merge_pgo_entry(i1, i0)
+        self.assertEqual(i0.scalar, auto_dynamic)
+        # merging tensors of different rank -> no-op, dst stays static
+        merge_pgo_entry(t0, t2)
+        self.assertEqual(t2.size, (4, 4))
+        # merging tensors of the same rank -> dynamic
+        merge_pgo_entry(t1, t2)
+        self.assertEqual(t2.size, (auto_dynamic, auto_dynamic))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py
index ba41fabfa0f4..1a2c98ee6c7d 100644
--- a/torch/_dynamo/pgo.py
+++ b/torch/_dynamo/pgo.py
@@ -671,6 +671,16 @@ def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
     return code_state_str
 
 
+def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None:
+    def rank(entry: FrameStateSizeEntry) -> int:
+        if not isinstance(entry.size, tuple):  # scalar
+            return -1
+        return len(entry.size)
+
+    if rank(src) == rank(dst):  # both tensors of the same rank, or both scalars
+        dst |= src
+
+
 @CacheArtifactFactory.register
 class PGOCacheArtifact(CacheArtifact):
     @override
@@ -825,7 +835,9 @@ def add_extra_remote_code_state(cache_key: str) -> None:
                         # where one entry might be 1-d, the other 2-d.
                         # or if entries are of different types?
                         # with local source naming, could be scalar vs. tensor
-                        _CODE_STATE[code_id].automatic_dynamic[src] |= entry
+                        merge_pgo_entry(
+                            entry, _CODE_STATE[code_id].automatic_dynamic[src]
+                        )
                     else:
                         _CODE_STATE[code_id] = state
                     # log to tlparse
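
A minimal, self-contained sketch of the merge semantics the new test exercises follows. Here `Entry`, `AUTO_DYNAMIC`, and the `__ior__` join are hypothetical stand-ins for `torch._dynamo.pgo.FrameStateSizeEntry`, its `auto_dynamic` sentinel, and its real merge operator; only the rank guard in `merge_pgo_entry` mirrors the patch.

from dataclasses import dataclass
from typing import Optional

AUTO_DYNAMIC = object()  # stand-in for torch._dynamo.pgo's auto_dynamic sentinel


@dataclass
class Entry:
    # Hypothetical stand-in for FrameStateSizeEntry: scalar inputs populate
    # `scalar`, tensor inputs populate `size`.
    scalar: object = None
    size: Optional[tuple] = None

    def __ior__(self, other: "Entry") -> "Entry":
        # Simplified join: values that agree stay static, values that
        # disagree become dynamic.
        if self.scalar != other.scalar:
            self.scalar = AUTO_DYNAMIC
        if isinstance(self.size, tuple) and isinstance(other.size, tuple):
            self.size = tuple(
                a if a == b else AUTO_DYNAMIC for a, b in zip(self.size, other.size)
            )
        return self


def merge_pgo_entry(src: Entry, dst: Entry) -> None:
    # Same rank guard as the patch: merge scalars into scalars and tensors
    # into tensors of equal rank; skip every other combination.
    def rank(entry: Entry) -> int:
        if not isinstance(entry.size, tuple):  # scalar
            return -1
        return len(entry.size)

    if rank(src) == rank(dst):
        dst |= src


t2 = Entry(size=(4, 4))
merge_pgo_entry(Entry(size=(2,)), t2)    # rank 1 vs. rank 2 -> no-op
assert t2.size == (4, 4)
merge_pgo_entry(Entry(size=(2, 2)), t2)  # same rank -> per-dimension join
assert t2.size == (AUTO_DYNAMIC, AUTO_DYNAMIC)

Run as a plain script, both asserts pass: the rank mismatch leaves t2 static at (4, 4), while the same-rank merge widens every differing dimension to the dynamic sentinel, matching the behavior the new test checks.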