[PGO] handle PGO profile merges (#162097)

Avoid merging profile entries from an extra PGO key when the same source has a different rank. This is unlikely to happen (it would require a code hash match while the source variable's type changes), but guard against it to be safe.
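For illustration only, a minimal self-contained sketch of the merge semantics that the new test below exercises. `ToyEntry`, `merge`, and `AUTO_DYNAMIC` are hypothetical stand-ins for the internal `FrameStateSizeEntry`, `merge_pgo_entry`, and `auto_dynamic`; the real entries track scalar, size, and stride fields separately, so treat this as a simplified model rather than the actual implementation.

from dataclasses import dataclass
from typing import Tuple, Union

AUTO_DYNAMIC = object()  # stand-in for the auto_dynamic sentinel

@dataclass
class ToyEntry:
    # int for a Python scalar, tuple of ints for a tensor's observed sizes
    size: Union[int, Tuple[object, ...]]

def rank(entry: ToyEntry) -> int:
    # scalars have no rank; tensors report their number of dimensions
    return -1 if not isinstance(entry.size, tuple) else len(entry.size)

def merge(src: ToyEntry, dst: ToyEntry) -> None:
    if rank(src) != rank(dst):
        return  # rank mismatch (scalar vs. tensor, 1-d vs. 2-d, ...): skip the merge
    if isinstance(dst.size, tuple):
        # same-rank tensors: any dimension that differs becomes dynamic
        dst.size = tuple(
            a if a == b else AUTO_DYNAMIC for a, b in zip(src.size, dst.size)
        )
    else:
        # scalars: differing values become dynamic
        dst.size = dst.size if src.size == dst.size else AUTO_DYNAMIC

a, b = ToyEntry(0), ToyEntry(2)
merge(b, a)                        # differing scalars -> dynamic
assert a.size is AUTO_DYNAMIC
t1, t2 = ToyEntry((2, 2)), ToyEntry((4, 4))
merge(ToyEntry((2,)), t2)          # different rank -> merge skipped (this fix)
assert t2.size == (4, 4)
merge(t1, t2)                      # same rank, different sizes -> per-dim dynamic
assert t2.size == (AUTO_DYNAMIC, AUTO_DYNAMIC)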

Differential Revision: D81299840

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162097
Approved by: https://github.com/bobrenjc93
Author: Pian Pawakapan
Date: 2025-09-05 04:58:15 +00:00
Committed by: PyTorch MergeBot
Parent: 494878a11b
Commit: 5da573c42c
2 changed files with 57 additions and 1 deletion

@@ -453,6 +453,50 @@ def run(cnt):
        f(t(4, 2), t(2, 4))
        self.assertEqual(cnts.frame_count, 1)

    def test_profile_merges(self):
        from torch._dynamo.pgo import auto_dynamic, merge_pgo_entry

        @torch.compile(backend="eager", fullgraph=True)
        def f(ints, t_scalar, tensors):
            # arbitrary compute
            return ints[0] + ints[1], t_scalar + 1, [t + 1 for t in tensors]

        # single static run
        f(
            [0, 2],
            torch.tensor(0),
            [
                torch.randn(2),
                torch.randn(2, 2),
                torch.randn(4, 4),
            ],
        )

        # collect profiles
        profile = next(
            iter(torch._dynamo.pgo.get_code_state().values())
        ).automatic_dynamic
        i0, i1 = profile["L['ints'][0]"], profile["L['ints'][1]"]
        ts = profile["L['t_scalar']"]
        t0, t1, t2 = (
            profile["L['tensors'][0]"],
            profile["L['tensors'][1]"],
            profile["L['tensors'][2]"],
        )

        # merging same scalar, or tensor into scalar -> no-op
        merge_pgo_entry(i0, i0)
        merge_pgo_entry(ts, i0)
        merge_pgo_entry(t0, i0)
        self.assertEqual(i0.scalar, 0)

        # merging different scalars -> dynamic
        merge_pgo_entry(i1, i0)
        self.assertEqual(i0.scalar, auto_dynamic)

        # merging different rank tensors -> static
        merge_pgo_entry(t0, t2)
        self.assertEqual(t2.size, (4, 4))

        # merging same rank tensors -> dynamic
        merge_pgo_entry(t1, t2)
        self.assertEqual(t2.size, (auto_dynamic, auto_dynamic))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

@@ -671,6 +671,16 @@ def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
    return code_state_str


def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None:
    def rank(entry: FrameStateSizeEntry) -> int:
        if not isinstance(entry.size, tuple):  # scalar
            return -1
        return len(entry.size)

    if rank(src) == rank(dst):  # both tensors same rank, or both scalars
        dst |= src


@CacheArtifactFactory.register
class PGOCacheArtifact(CacheArtifact):
    @override
@@ -825,7 +835,9 @@ def add_extra_remote_code_state(cache_key: str) -> None:
                # where one entry might be 1-d, the other 2-d.
                # or if entries are of different types?
                # with local source naming, could be scalar vs. tensor
-               _CODE_STATE[code_id].automatic_dynamic[src] |= entry
+               merge_pgo_entry(
+                   entry, _CODE_STATE[code_id].automatic_dynamic[src]
+               )
            else:
                _CODE_STATE[code_id] = state
        # log to tlparse