[PGO] ignore extra PGO key if warm/cold cache present (#163810)

Summary: avoids PGO profile merges: the extra PGO read key is now consulted only when neither the local (warm) nor the default remote (cold) profile is found, and the extra profile is adopted wholesale rather than merged into the existing code state. The merge_pgo_entry helper and its test_profile_merges test are removed accordingly.

Test Plan: test_pgo

Differential Revision: D83200714

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163810
Approved by: https://github.com/bobrenjc93
Author: Pian Pawakapan
Date: 2025-09-25 07:16:02 +00:00
Committed by: PyTorch MergeBot
Parent: eb7f4e0004
Commit: 5f90e8c7ae

2 changed files with 16 additions and 78 deletions

test/dynamo/test_pgo.py

@@ -443,59 +443,15 @@ def run(cnt):
         f(t(2, 4), t(2, 2))
         f(t(4, 2), t(2, 2))

-        # with default remote (dynamic x) + extra remote (dynamic y),
-        # we should be able to wobble x & y with no recompiles.
+        # with both default remote present, we ignore extra remote.
         self.reset()
         cnts.clear()
         with torch.compiler.config.patch(pgo_extra_read_key="sticky_1"):
             f(t(2, 2), t(2, 2))
             f(t(2, 4), t(4, 2))
             f(t(4, 2), t(2, 4))
-            f(t(6, 8), t(2, 2))
-        self.assertEqual(cnts.frame_count, 1)
-
-    def test_profile_merges(self):
-        from torch._dynamo.pgo import auto_dynamic, merge_pgo_entry
-
-        @torch.compile(backend="eager", fullgraph=True)
-        def f(ints, t_scalar, tensors):
-            # arbitrary compute
-            return ints[0] + ints[1], t_scalar + 1, [t + 1 for t in tensors]
-
-        # single static run
-        f(
-            [0, 2],
-            torch.tensor(0),
-            [
-                torch.randn(2),
-                torch.randn(2, 2),
-                torch.randn(4, 4),
-            ],
-        )
-
-        # collect profiles
-        profile = next(
-            iter(torch._dynamo.pgo.get_code_state().values())
-        ).automatic_dynamic
-        i0, i1 = profile["L['ints'][0]"], profile["L['ints'][1]"]
-        ts = profile["L['t_scalar']"]
-        t0, t1, t2 = (
-            profile["L['tensors'][0]"],
-            profile["L['tensors'][1]"],
-            profile["L['tensors'][2]"],
-        )
-
-        # merging same scalar, or tensor into scalar -> no-op
-        merge_pgo_entry(i0, i0)
-        merge_pgo_entry(ts, i0)
-        merge_pgo_entry(t0, i0)
-        self.assertEqual(i0.scalar, 0)
-        # merging different scalars -> dynamic
-        merge_pgo_entry(i1, i0)
-        self.assertEqual(i0.scalar, auto_dynamic)
-        # merging different rank tensors -> static
-        merge_pgo_entry(t0, t2)
-        self.assertEqual(t2.size, (4, 4))
-        # merging same rank tensors -> dynamic
-        merge_pgo_entry(t1, t2)
-        self.assertEqual(t2.size, (auto_dynamic, auto_dynamic))
+            f(t(2, 2), t(2, 4))
+        self.assertEqual(cnts.frame_count, 2)

 if __name__ == "__main__":
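Note on the new expectation: with the default remote profile (dynamic x) present, the extra profile (dynamic y) is now ignored, so wobbling y costs exactly one recompile and the frame count becomes 2 instead of 1. The usage pattern exercised by the test, with an arbitrary sticky key name:

    import torch

    with torch.compiler.config.patch(pgo_extra_read_key="sticky_1"):
        ...  # compiled calls; the extra profile applies only on a full PGO miss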

torch/_dynamo/pgo.py

@@ -671,16 +671,6 @@ def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
     return code_state_str


-def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None:
-    def rank(entry: FrameStateSizeEntry) -> int:
-        if not isinstance(entry.size, tuple):  # scalar
-            return -1
-        return len(entry.size)
-
-    if rank(src) == rank(dst):  # both tensors same rank, or both scalars
-        dst |= src


 @CacheArtifactFactory.register
 class PGOCacheArtifact(CacheArtifact):
     @override
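Note: the deleted helper only gated on rank; the actual merge came from FrameStateSizeEntry's in-place or (dst |= src). A toy model of those semantics, simplified and using a placeholder for torch._dynamo.pgo.auto_dynamic, matching the expectations of the deleted test_profile_merges:

    AUTO_DYNAMIC = "auto_dynamic"  # placeholder sentinel, not the real singleton

    def merge_sizes(src: tuple, dst: tuple) -> tuple:
        # equal ranks only, mirroring the deleted rank() gate
        assert len(src) == len(dst)
        # per dimension: equal sizes survive, any mismatch goes dynamic
        return tuple(d if d == s else AUTO_DYNAMIC for s, d in zip(src, dst))

    assert merge_sizes((4, 4), (4, 4)) == (4, 4)
    assert merge_sizes((2, 2), (4, 4)) == (AUTO_DYNAMIC, AUTO_DYNAMIC)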
@@ -805,7 +795,7 @@ def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
     return None


-def add_extra_remote_code_state(cache_key: str) -> None:
+def get_extra_remote_code_state(cache_key: str) -> None:
     """
     Reads an additional PGO profile from the given cache key, and merges it with the default PGO profile.
     """
@@ -815,34 +805,23 @@ def add_extra_remote_code_state(cache_key: str) -> None:
     remote_cache = get_remote_cache()
     if remote_cache is not None:
         with dynamo_timed(
-            name := "pgo.add_extra_remote_code_state",
+            name := "pgo.get_extra_remote_code_state",
             log_pt2_compile_event=True,
             dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
         ):
             CompileEventLogger.pt2_compile(name, cache_key=cache_key)
             code_state = lookup_remote_cache_entry(remote_cache, cache_key)
             log.info(
-                "add_extra_code_state %s hit, %d entries",
+                "get_extra_code_state %s hit, %d entries",
                 cache_key,
                 len(code_state) if code_state is not None else 0,
             )
             if code_state is not None:
-                # merge the code state into the current one
-                for code_id, state in code_state.items():
-                    if code_id in _CODE_STATE:
-                        for src, entry in state.automatic_dynamic.items():
-                            # NOTE: maybe we need an "unsafe" merge to handle this,
-                            # where one entry might be 1-d, the other 2-d.
-                            # or if entries are of different types?
-                            # with local source naming, could be scalar vs. tensor
-                            merge_pgo_entry(
-                                entry, _CODE_STATE[code_id].automatic_dynamic[src]
-                            )
-                    else:
-                        _CODE_STATE[code_id] = state
+                assert not _CODE_STATE
+                _CODE_STATE = code_state
                 # log to tlparse
                 trace_structured_artifact(
-                    "add_extra_remote_code_state",
+                    "get_extra_remote_code_state",
                     "string",
                     lambda: render_code_state(code_state),
                 )
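Note: the entry-by-entry merge, and the open questions in the deleted NOTE about mixed ranks and scalar-vs-tensor entries, are gone entirely. The caller now guarantees the global code state is empty, so the extra profile is adopted wholesale; a sketch of the resulting contract (the enclosing global _CODE_STATE declaration is assumed to exist earlier in the function):

    if code_state is not None:
        assert not _CODE_STATE    # guaranteed by the caller's not-_CODE_STATE gate
        _CODE_STATE = code_state  # adopt, never merge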
@@ -867,11 +846,14 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
     if local_code_state is None:
         get_remote_code_state(cache_key)

-    # Attempt additional remote
-    if (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None:
+    # Attempt additional remote if neither local/default remote succeeded
+    if (
+        not _CODE_STATE
+        and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
+    ):
         extra_read_key = get_extra_cache_key(sticky_read)
         if extra_read_key is not None:
-            add_extra_remote_code_state(extra_read_key)
+            get_extra_remote_code_state(extra_read_key)

     log.info("get_code_state using default")