[PGO] ignore extra PGO key if warm/cold cache present (#163810)
Summary: avoids PGO profile merges

Test Plan: test_pgo

Differential Revision: D83200714

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163810
Approved by: https://github.com/bobrenjc93
committed by PyTorch MergeBot
parent eb7f4e0004
commit 5f90e8c7ae
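In short: pgo_extra_read_key lets a compile job seed its automatic-dynamic profile from an extra remote PGO key. Before this change the extra profile was merged into whatever state was already loaded; after it, the extra key is consulted only when neither the local nor the default remote profile produced any state. A minimal usage sketch, assuming only what the test below exercises (the key name "sticky_1" and the function fn are illustrative):

import torch

def fn(x):
    return x + 1

# If a warm (in-process) or cold (local / default remote) PGO profile already
# exists, the extra key below is now ignored rather than merged in.
with torch.compiler.config.patch(pgo_extra_read_key="sticky_1"):
    compiled = torch.compile(fn, backend="eager")
    compiled(torch.randn(2, 2))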
@@ -443,59 +443,15 @@ def run(cnt):
         f(t(2, 4), t(2, 2))
         f(t(4, 2), t(2, 2))

-        # with default remote (dynamic x) + extra remote (dynamic y),
-        # we should be able to wobble x & y with no recompiles.
+        # with both default remote present, we ignore extra remote.
         self.reset()
         cnts.clear()
         with torch.compiler.config.patch(pgo_extra_read_key="sticky_1"):
             f(t(2, 2), t(2, 2))
-            f(t(2, 4), t(4, 2))
-            f(t(4, 2), t(2, 4))
-            f(t(6, 8), t(2, 2))
-        self.assertEqual(cnts.frame_count, 1)
-
-    def test_profile_merges(self):
-        from torch._dynamo.pgo import auto_dynamic, merge_pgo_entry
-
-        @torch.compile(backend="eager", fullgraph=True)
-        def f(ints, t_scalar, tensors):
-            # arbitrary compute
-            return ints[0] + ints[1], t_scalar + 1, [t + 1 for t in tensors]
-
-        # single static run
-        f(
-            [0, 2],
-            torch.tensor(0),
-            [
-                torch.randn(2),
-                torch.randn(2, 2),
-                torch.randn(4, 4),
-            ],
-        )
-        # collect profiles
-        profile = next(
-            iter(torch._dynamo.pgo.get_code_state().values())
-        ).automatic_dynamic
-        i0, i1 = profile["L['ints'][0]"], profile["L['ints'][1]"]
-        ts = profile["L['t_scalar']"]
-        t0, t1, t2 = (
-            profile["L['tensors'][0]"],
-            profile["L['tensors'][1]"],
-            profile["L['tensors'][2]"],
-        )
-        # merging same scalar, or tensor into scalar -> no-op
-        merge_pgo_entry(i0, i0)
-        merge_pgo_entry(ts, i0)
-        merge_pgo_entry(t0, i0)
-        self.assertEqual(i0.scalar, 0)
-        # merging different scalars -> dynamic
-        merge_pgo_entry(i1, i0)
-        self.assertEqual(i0.scalar, auto_dynamic)
-        # merging different rank tensors -> static
-        merge_pgo_entry(t0, t2)
-        self.assertEqual(t2.size, (4, 4))
-        # merging same rank tensors -> dynamic
-        merge_pgo_entry(t1, t2)
-        self.assertEqual(t2.size, (auto_dynamic, auto_dynamic))
+            f(t(2, 2), t(2, 4))
+        self.assertEqual(cnts.frame_count, 2)


 if __name__ == "__main__":
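The helpers t and cnts above come from earlier in the test file and are not part of this hunk. A self-contained sketch of the same recompile-counting pattern, assuming the stock CompileCounter backend from torch._dynamo.testing:

import torch
from torch._dynamo.testing import CompileCounter

cnts = CompileCounter()  # counts how many frames Dynamo compiles

def t(*shape):
    return torch.randn(*shape)

@torch.compile(backend=cnts)
def f(x, y):
    return x + 1, y + 1

f(t(2, 2), t(2, 2))  # first call compiles: frame_count == 1
f(t(2, 4), t(2, 2))  # shape change recompiles, then that dim goes dynamic
print(cnts.frame_count)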
@@ -671,16 +671,6 @@ def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
     return code_state_str


-def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None:
-    def rank(entry: FrameStateSizeEntry) -> int:
-        if not isinstance(entry.size, tuple):  # scalar
-            return -1
-        return len(entry.size)
-
-    if rank(src) == rank(dst):  # both tensors same rank, or both scalars
-        dst |= src
-
-
 @CacheArtifactFactory.register
 class PGOCacheArtifact(CacheArtifact):
     @override
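For context, the removed helper merged one automatic-dynamic entry into another only when both had the same rank (both scalars, or tensors of equal dimensionality); a rank mismatch was a no-op, which is what the deleted test asserted. A toy standalone model of that rule, where Entry and AUTO_DYNAMIC are stand-ins for FrameStateSizeEntry and auto_dynamic:

from dataclasses import dataclass
from typing import Tuple, Union

AUTO_DYNAMIC = "auto_dynamic"  # stand-in for the real sentinel

@dataclass
class Entry:
    size: Union[int, Tuple, str]  # int: scalar value; tuple: tensor sizes

def rank(entry: Entry) -> int:
    return -1 if not isinstance(entry.size, tuple) else len(entry.size)

def merge(src: Entry, dst: Entry) -> None:
    if rank(src) != rank(dst):
        return  # scalar vs. tensor, or differing ranks: leave dst untouched
    if not isinstance(dst.size, tuple):
        if dst.size != src.size:
            dst.size = AUTO_DYNAMIC  # differing scalars widen to dynamic
        return
    # same-rank tensors: widen each dimension that disagrees
    dst.size = tuple(d if d == s else AUTO_DYNAMIC for d, s in zip(dst.size, src.size))

a, b = Entry((2, 2)), Entry((4, 4))
merge(a, b)            # same rank, different sizes: b becomes fully dynamic
merge(Entry((2,)), b)  # rank mismatch: no-op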
@@ -805,7 +795,7 @@ def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
     return None


-def add_extra_remote_code_state(cache_key: str) -> None:
+def get_extra_remote_code_state(cache_key: str) -> None:
     """
     Reads an additional PGO profile from the given cache key, and merges it with the default PGO profile.
     """
@@ -815,34 +805,23 @@ def add_extra_remote_code_state(cache_key: str) -> None:
     remote_cache = get_remote_cache()
     if remote_cache is not None:
         with dynamo_timed(
-            name := "pgo.add_extra_remote_code_state",
+            name := "pgo.get_extra_remote_code_state",
             log_pt2_compile_event=True,
             dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
         ):
             CompileEventLogger.pt2_compile(name, cache_key=cache_key)
             code_state = lookup_remote_cache_entry(remote_cache, cache_key)
             log.info(
-                "add_extra_code_state %s hit, %d entries",
+                "get_extra_code_state %s hit, %d entries",
                 cache_key,
                 len(code_state) if code_state is not None else 0,
             )
             if code_state is not None:
-                # merge the code state into the current one
-                for code_id, state in code_state.items():
-                    if code_id in _CODE_STATE:
-                        for src, entry in state.automatic_dynamic.items():
-                            # NOTE: maybe we need an "unsafe" merge to handle this,
-                            # where one entry might be 1-d, the other 2-d.
-                            # or if entries are of different types?
-                            # with local source naming, could be scalar vs. tensor
-                            merge_pgo_entry(
-                                entry, _CODE_STATE[code_id].automatic_dynamic[src]
-                            )
-                    else:
-                        _CODE_STATE[code_id] = state
+                assert not _CODE_STATE
+                _CODE_STATE = code_state
                 # log to tlparse
                 trace_structured_artifact(
-                    "add_extra_remote_code_state",
+                    "get_extra_remote_code_state",
                     "string",
                     lambda: render_code_state(code_state),
                 )
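The net effect of this hunk: instead of merging the fetched extra profile entry-by-entry into _CODE_STATE, the renamed function adopts it wholesale, and only on the empty-state path. A compressed sketch of that control flow (CODE_STATE and get_extra_code_state are stand-ins for the module global and the real function):

CODE_STATE: dict = {}  # stand-in for pgo._CODE_STATE

def get_extra_code_state(fetched) -> None:
    global CODE_STATE
    if fetched is None:
        return  # remote lookup missed; nothing to adopt
    # only reached when neither warm nor cold cache produced a profile,
    # so there is no existing state to merge with
    assert not CODE_STATE
    CODE_STATE = fetched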
@@ -867,11 +846,14 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
         if local_code_state is None:
             get_remote_code_state(cache_key)

-        # Attempt additional remote
-        if (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None:
+        # Attempt additional remote if neither local/default remote succeeded
+        if (
+            not _CODE_STATE
+            and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
+        ):
             extra_read_key = get_extra_cache_key(sticky_read)
             if extra_read_key is not None:
-                add_extra_remote_code_state(extra_read_key)
+                get_extra_remote_code_state(extra_read_key)

         log.info("get_code_state using default")
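This caller-side gate is what implements the PR title: the extra PGO key is read only when the warm/cold caches came up empty. Restated as a standalone predicate (names are stand-ins, not the real API):

from typing import Mapping, Optional

def should_read_extra(code_state: Mapping, extra_read_key: Optional[str]) -> bool:
    # consult the extra key only if local and default remote both found nothing
    return not code_state and extra_read_key is not None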