[PGO] log missing sources in allowlist (#164881)

Summary:
- logs missing dynamic sources
- emits MLHub insight only on size mismatch recompiles

Test Plan: test_pgo

Differential Revision: D84098898

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164881
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Pian Pawakapan
2025-10-09 04:39:07 +00:00
committed by PyTorch MergeBot
parent 7b691546d2
commit 1f73b96668
3 changed files with 62 additions and 18 deletions

View File

@ -62,6 +62,13 @@ class PgoTest(torch._dynamo.test_case.TestCase):
force_nn_module_property_static_shapes=False,
)
def test_whitelist_suggestion(self):
from torch._dynamo.pgo import (
_collect_dynamic_sources,
_collect_missing_sources,
get_code_state,
render_code_state,
)
cnts = CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
@ -83,15 +90,19 @@ class PgoTest(torch._dynamo.test_case.TestCase):
]
def check_whitelist(sources_):
state = torch._dynamo.pgo.render_code_state(
torch._dynamo.pgo.get_code_state()
)
state = render_code_state(get_code_state())
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(
1
)
for src in sources_:
self.assertTrue(src in whitelist)
def check_num_missing_whitelist(expected):
frame_state = next(iter(get_code_state().values()))
all_dynamic_sources = _collect_dynamic_sources(frame_state)
missing_whitelist = _collect_missing_sources(all_dynamic_sources)
self.assertEqual(len(missing_whitelist), expected)
# check growing whitelist
f = Foo()
f(torch.randn(2, 4), torch.randn(4))
@ -107,11 +118,13 @@ class PgoTest(torch._dynamo.test_case.TestCase):
f.attr = torch.randn(8)
f(torch.randn(8, 8), torch.randn(8))
check_whitelist(sources)
check_num_missing_whitelist(5)
# now use suggested whitelist
self.reset()
cnts.clear()
state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
code_state = get_code_state()
state = render_code_state(code_state)
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1)
with torch.compiler.config.patch(dynamic_sources=whitelist):
f = Foo()
@ -121,6 +134,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
f.attr = torch.randn(8)
f(torch.randn(8, 8), torch.randn(8))
self.assertEqual(cnts.frame_count, 1)
check_num_missing_whitelist(0)
def test_no_empty_graph_allowlist(self):
@torch._dynamo.disable

View File

@ -123,7 +123,11 @@ from .guards import (
)
from .hooks import Hooks
from .output_graph import DynamoTracerOutput, OutputGraphCommon
from .pgo import log_frame_dynamic_whitelist, put_code_state
from .pgo import (
_log_size_mismatch_recompile,
log_frame_dynamic_whitelist,
put_code_state,
)
from .replay_record import ExecutionRecord
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .symbolic_convert import (
@ -1591,6 +1595,8 @@ def _compile(
and output_graph.has_outputs()
):
log_frame_dynamic_whitelist(code)
if recompile_reason and "size mismatch at index" in recompile_reason:
_log_size_mismatch_recompile()
return guarded_code
except Exception as e:

View File

@ -175,6 +175,7 @@ class CodeState:
# Snapshot of the code state as loaded at process start, used to detect changes.
_INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
# Live per-code-object PGO state for this process; lazily initialized.
_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
# Whether the once-per-rank MLHub dynamic-shapes insight has been emitted.
_LOGGED_DYNAMIC_ALLOWLIST: bool = False
# Cache of sources already confirmed to be on the dynamic-source allowlist,
# so repeated allowlist checks are skipped on later frames.
_KNOWN_DYNAMIC_SOURCES: set[str] = set()
@dataclasses.dataclass(frozen=True)
@ -630,26 +631,49 @@ def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]:
return dynamic_sources
def _collect_missing_sources(all_sources: OrderedSet[str]) -> OrderedSet[str]:
    """Return the subset of ``all_sources`` not covered by the dynamic-source allowlist.

    Sources that pass the allowlist check are remembered in the module-level
    ``_KNOWN_DYNAMIC_SOURCES`` cache so each source string is checked against
    the allowlist at most once per process.
    """
    from torch._dynamo.variables.builder import is_dynamic_source

    global _KNOWN_DYNAMIC_SOURCES
    missing: OrderedSet[str] = OrderedSet()
    for source in all_sources:
        if source in _KNOWN_DYNAMIC_SOURCES:
            continue
        if is_dynamic_source(source):
            # Cache the positive result — presumably the allowlist does not
            # shrink within a process, so re-checking would be wasted work.
            _KNOWN_DYNAMIC_SOURCES.add(source)
        else:
            missing.add(source)
    return missing
def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None:
    """Log the dynamic-source allowlist suggested by PGO for ``f_code``'s frame.

    Emits a ``pgo.dynamic_whitelist`` compile event carrying both the full
    suggested allowlist for this frame and the subset of those sources that
    the currently configured allowlist does not yet cover. The first time any
    frame produces a non-empty allowlist, also emits a once-per-rank MLHub
    insight pointing the user at the recompilation-reduction instructions.
    """
    code_id = CodeId.make(f_code)
    frame_state = get_code_state()[code_id]
    all_dynamic_sources = _collect_dynamic_sources(frame_state)
    frame_whitelist = ",".join(all_dynamic_sources)
    # Sources PGO found to be dynamic but that are missing from the configured
    # allowlist — i.e. what the user would add to avoid these recompiles.
    missing_whitelist = ",".join(_collect_missing_sources(all_dynamic_sources))
    if frame_whitelist:
        with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True):
            CompileEventLogger.pt2_compile(
                name,
                recompile_dynamic_whitelist=frame_whitelist,
                missing_dynamic_whitelist=missing_whitelist,
            )
        # The once-per-rank MLHub insight emitted here is byte-identical to the
        # one emitted on size-mismatch recompiles, so delegate to the shared
        # helper instead of duplicating the add_mlhub_insight call (it also
        # guards on and sets _LOGGED_DYNAMIC_ALLOWLIST for us).
        _log_size_mismatch_recompile()
def _log_size_mismatch_recompile() -> None:
    """Emit the once-per-rank MLHub insight for a dynamic-shape recompilation."""
    global _LOGGED_DYNAMIC_ALLOWLIST
    # add mlhub insight only once per rank
    if _LOGGED_DYNAMIC_ALLOWLIST:
        return
    _LOGGED_DYNAMIC_ALLOWLIST = True
    torch._utils_internal.add_mlhub_insight(
        category="dynamic_shapes_analysis",
        insight="Dynamic shape recompilation detected",
        insight_description="PGO detected a recompilation due to dynamic shapes. \
            Please follow the instruction from the action link to reduce \
            recompilation overhead.",
    )
def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str: