[PGO] handle PGO profile merges (#162097)

Avoid merging entries from an extra PGO key when the same source has a different rank. This is unlikely to happen (it requires a code hash match and the source variable's type to change), but it is safe to guard against.

Differential Revision: D81299840
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162097
Approved by: https://github.com/bobrenjc93
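For orientation before the diff: the change adds a rank guard in front of the existing FrameStateSizeEntry merge. Below is a minimal standalone sketch of that idea; SizeEntry, merge, and AUTO_DYNAMIC are simplified stand-ins for pgo.py's FrameStateSizeEntry, merge_pgo_entry, and the auto_dynamic sentinel, not the real implementation.

    from dataclasses import dataclass
    from typing import Tuple, Union

    AUTO_DYNAMIC = "auto_dynamic"  # stand-in for PGO's auto_dynamic sentinel


    @dataclass
    class SizeEntry:
        # an int (or AUTO_DYNAMIC) models a scalar profile; a tuple models tensor sizes
        size: Union[int, str, Tuple]

        def rank(self) -> int:
            if not isinstance(self.size, tuple):  # scalar
                return -1
            return len(self.size)


    def merge(src: SizeEntry, dst: SizeEntry) -> None:
        # Skip the merge entirely when the two entries do not profile the
        # same kind of value (scalar vs. tensor, or tensors of different
        # rank); an extra PGO key's profile may have reused a source name
        # for a differently shaped value, and merging would corrupt dst.
        if src.rank() != dst.rank():
            return
        if isinstance(dst.size, tuple):
            # per-dimension merge: agreeing dims stay static, others go dynamic
            dst.size = tuple(
                a if a == b else AUTO_DYNAMIC for a, b in zip(src.size, dst.size)
            )
        else:
            dst.size = dst.size if src.size == dst.size else AUTO_DYNAMIC


    # different rank: merge is skipped, dst stays static
    t0, t2 = SizeEntry((2,)), SizeEntry((4, 4))
    merge(t0, t2)
    assert t2.size == (4, 4)

    # same rank, different sizes: every disagreeing dim goes dynamic
    t1 = SizeEntry((2, 2))
    merge(t1, t2)
    assert t2.size == (AUTO_DYNAMIC, AUTO_DYNAMIC)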
committed by: PyTorch MergeBot
parent: 494878a11b
commit: 5da573c42c
@@ -453,6 +453,50 @@ def run(cnt):
         f(t(4, 2), t(2, 4))
         self.assertEqual(cnts.frame_count, 1)
 
+    def test_profile_merges(self):
+        from torch._dynamo.pgo import auto_dynamic, merge_pgo_entry
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def f(ints, t_scalar, tensors):
+            # arbitrary compute
+            return ints[0] + ints[1], t_scalar + 1, [t + 1 for t in tensors]
+
+        # single static run
+        f(
+            [0, 2],
+            torch.tensor(0),
+            [
+                torch.randn(2),
+                torch.randn(2, 2),
+                torch.randn(4, 4),
+            ],
+        )
+        # collect profiles
+        profile = next(
+            iter(torch._dynamo.pgo.get_code_state().values())
+        ).automatic_dynamic
+        i0, i1 = profile["L['ints'][0]"], profile["L['ints'][1]"]
+        ts = profile["L['t_scalar']"]
+        t0, t1, t2 = (
+            profile["L['tensors'][0]"],
+            profile["L['tensors'][1]"],
+            profile["L['tensors'][2]"],
+        )
+        # merging same scalar, or tensor into scalar -> no-op
+        merge_pgo_entry(i0, i0)
+        merge_pgo_entry(ts, i0)
+        merge_pgo_entry(t0, i0)
+        self.assertEqual(i0.scalar, 0)
+        # merging different scalars -> dynamic
+        merge_pgo_entry(i1, i0)
+        self.assertEqual(i0.scalar, auto_dynamic)
+        # merging different-rank tensors -> no-op, dst stays static
+        merge_pgo_entry(t0, t2)
+        self.assertEqual(t2.size, (4, 4))
+        # merging same-rank tensors -> dynamic
+        merge_pgo_entry(t1, t2)
+        self.assertEqual(t2.size, (auto_dynamic, auto_dynamic))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -671,6 +671,16 @@ def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
     return code_state_str
 
 
+def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None:
+    def rank(entry: FrameStateSizeEntry) -> int:
+        if not isinstance(entry.size, tuple):  # scalar
+            return -1
+        return len(entry.size)
+
+    if rank(src) == rank(dst):  # both tensors same rank, or both scalars
+        dst |= src
+
+
 @CacheArtifactFactory.register
 class PGOCacheArtifact(CacheArtifact):
     @override
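A usage sketch mirroring test_profile_merges above; `profile` is assumed to be the automatic_dynamic dict pulled from get_code_state(), exactly as in the test:

    from torch._dynamo.pgo import merge_pgo_entry

    scalar = profile["L['t_scalar']"]  # scalar entry (rank -1 under the guard)
    vec = profile["L['tensors'][0]"]   # 1-d tensor entry (rank 1)
    mat = profile["L['tensors'][1]"]   # 2-d tensor entry (rank 2)

    merge_pgo_entry(scalar, mat)  # scalar vs. tensor: ranks differ, no-op
    merge_pgo_entry(vec, mat)     # 1-d vs. 2-d: ranks differ, no-op
    merge_pgo_entry(vec, vec)     # same rank: falls through to dst |= src

Mapping scalars to rank -1 means a scalar entry can never rank-match a tensor entry, so only like-kinded entries reach the existing `|=` merge, which degrades disagreeing values to auto_dynamic.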
@@ -825,7 +835,9 @@ def add_extra_remote_code_state(cache_key: str) -> None:
                     # where one entry might be 1-d, the other 2-d,
                     # or if entries are of different types?
                     # with local source naming, could be scalar vs. tensor
-                    _CODE_STATE[code_id].automatic_dynamic[src] |= entry
+                    merge_pgo_entry(
+                        entry, _CODE_STATE[code_id].automatic_dynamic[src]
+                    )
         else:
             _CODE_STATE[code_id] = state
     # log to tlparse
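Note the argument order at this call site: the entry decoded from the extra remote key is the src, and the locally cached entry is the dst, so only local state can be mutated (in place, via dst |= src), and a rank-mismatched remote entry is dropped instead of degrading the local profile to dynamic. A sketch of the contract, reusing the names from the hunk above:

    local = _CODE_STATE[code_id].automatic_dynamic[src]  # locally cached entry
    merge_pgo_entry(entry, local)  # mutates `local` in place only if ranks match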