mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PGO] log missing sources in allowlist (#164881)
Summary:
- logs missing dynamic sources
- emits MLHub insight only on size mismatch recompiles

Test Plan: test_pgo

Differential Revision: D84098898

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164881
Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
7b691546d2
commit
1f73b96668
@ -62,6 +62,13 @@ class PgoTest(torch._dynamo.test_case.TestCase):
|
||||
force_nn_module_property_static_shapes=False,
|
||||
)
|
||||
def test_whitelist_suggestion(self):
|
||||
from torch._dynamo.pgo import (
|
||||
_collect_dynamic_sources,
|
||||
_collect_missing_sources,
|
||||
get_code_state,
|
||||
render_code_state,
|
||||
)
|
||||
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
@ -83,15 +90,19 @@ class PgoTest(torch._dynamo.test_case.TestCase):
|
||||
]
|
||||
|
||||
def check_whitelist(sources_):
|
||||
state = torch._dynamo.pgo.render_code_state(
|
||||
torch._dynamo.pgo.get_code_state()
|
||||
)
|
||||
state = render_code_state(get_code_state())
|
||||
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(
|
||||
1
|
||||
)
|
||||
for src in sources_:
|
||||
self.assertTrue(src in whitelist)
|
||||
|
||||
def check_num_missing_whitelist(expected):
|
||||
frame_state = next(iter(get_code_state().values()))
|
||||
all_dynamic_sources = _collect_dynamic_sources(frame_state)
|
||||
missing_whitelist = _collect_missing_sources(all_dynamic_sources)
|
||||
self.assertEqual(len(missing_whitelist), expected)
|
||||
|
||||
# check growing whitelist
|
||||
f = Foo()
|
||||
f(torch.randn(2, 4), torch.randn(4))
|
||||
@ -107,11 +118,13 @@ class PgoTest(torch._dynamo.test_case.TestCase):
|
||||
f.attr = torch.randn(8)
|
||||
f(torch.randn(8, 8), torch.randn(8))
|
||||
check_whitelist(sources)
|
||||
check_num_missing_whitelist(5)
|
||||
|
||||
# now use suggested whitelist
|
||||
self.reset()
|
||||
cnts.clear()
|
||||
state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
|
||||
code_state = get_code_state()
|
||||
state = render_code_state(code_state)
|
||||
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1)
|
||||
with torch.compiler.config.patch(dynamic_sources=whitelist):
|
||||
f = Foo()
|
||||
@ -121,6 +134,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
|
||||
f.attr = torch.randn(8)
|
||||
f(torch.randn(8, 8), torch.randn(8))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
check_num_missing_whitelist(0)
|
||||
|
||||
def test_no_empty_graph_allowlist(self):
|
||||
@torch._dynamo.disable
|
||||
|
@ -123,7 +123,11 @@ from .guards import (
|
||||
)
|
||||
from .hooks import Hooks
|
||||
from .output_graph import DynamoTracerOutput, OutputGraphCommon
|
||||
from .pgo import log_frame_dynamic_whitelist, put_code_state
|
||||
from .pgo import (
|
||||
_log_size_mismatch_recompile,
|
||||
log_frame_dynamic_whitelist,
|
||||
put_code_state,
|
||||
)
|
||||
from .replay_record import ExecutionRecord
|
||||
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
|
||||
from .symbolic_convert import (
|
||||
@ -1591,6 +1595,8 @@ def _compile(
|
||||
and output_graph.has_outputs()
|
||||
):
|
||||
log_frame_dynamic_whitelist(code)
|
||||
if recompile_reason and "size mismatch at index" in recompile_reason:
|
||||
_log_size_mismatch_recompile()
|
||||
|
||||
return guarded_code
|
||||
except Exception as e:
|
||||
|
@ -175,6 +175,7 @@ class CodeState:
|
||||
_INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
|
||||
_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
|
||||
_LOGGED_DYNAMIC_ALLOWLIST: bool = False
|
||||
_KNOWN_DYNAMIC_SOURCES: set[str] = set()
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -630,26 +631,49 @@ def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]:
|
||||
return dynamic_sources
|
||||
|
||||
|
||||
def _collect_missing_sources(all_sources: OrderedSet[str]) -> OrderedSet[str]:
    """Return the entries of ``all_sources`` that are not marked dynamic.

    A source is left out of the result when it is already recorded in the
    module-level ``_KNOWN_DYNAMIC_SOURCES`` set, or when
    ``is_dynamic_source(src)`` is truthy. In the second case the source is
    also added to ``_KNOWN_DYNAMIC_SOURCES``, so later calls skip the
    ``is_dynamic_source`` lookup for it.

    Args:
        all_sources: source names to check, in iteration order.

    Returns:
        The remaining sources, in the order they were encountered.
    """
    # Imported inside the function, as in the original.
    from torch._dynamo.variables.builder import is_dynamic_source

    global _KNOWN_DYNAMIC_SOURCES
    unlisted: OrderedSet[str] = OrderedSet()
    for source in all_sources:
        if source not in _KNOWN_DYNAMIC_SOURCES:
            if is_dynamic_source(source):
                # Remember the positive answer so it is not recomputed.
                _KNOWN_DYNAMIC_SOURCES.add(source)
            else:
                unlisted.add(source)
    return unlisted
|
||||
|
||||
|
||||
def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None:
|
||||
global _LOGGED_DYNAMIC_ALLOWLIST
|
||||
global _KNOWN_DYNAMIC_SOURCES
|
||||
code_id = CodeId.make(f_code)
|
||||
frame_state = get_code_state()[code_id]
|
||||
frame_whitelist = ",".join(_collect_dynamic_sources(frame_state))
|
||||
all_dynamic_sources = _collect_dynamic_sources(frame_state)
|
||||
frame_whitelist = ",".join(all_dynamic_sources)
|
||||
missing_whitelist = ",".join(_collect_missing_sources(all_dynamic_sources))
|
||||
if frame_whitelist:
|
||||
with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True):
|
||||
CompileEventLogger.pt2_compile(
|
||||
name, recompile_dynamic_whitelist=frame_whitelist
|
||||
name,
|
||||
recompile_dynamic_whitelist=frame_whitelist,
|
||||
missing_dynamic_whitelist=missing_whitelist,
|
||||
)
|
||||
if not _LOGGED_DYNAMIC_ALLOWLIST:
|
||||
torch._utils_internal.add_mlhub_insight(
|
||||
category="dynamic_shapes_analysis",
|
||||
insight="Dynamic shape recompilation detected",
|
||||
insight_description="PGO detected a recompilation due to dynamic shapes. \
|
||||
Please follow the instruction from the action link to reduce \
|
||||
recompilation overhead.",
|
||||
)
|
||||
# add mlhub insight only once per rank
|
||||
_LOGGED_DYNAMIC_ALLOWLIST = True
|
||||
|
||||
|
||||
def _log_size_mismatch_recompile() -> None:
    """Emit the "Dynamic shape recompilation detected" MLHub insight once.

    The module-level ``_LOGGED_DYNAMIC_ALLOWLIST`` flag is checked first and
    set after the insight is added, so calls made while the flag is set do
    nothing.
    """
    global _LOGGED_DYNAMIC_ALLOWLIST
    if _LOGGED_DYNAMIC_ALLOWLIST:
        return
    torch._utils_internal.add_mlhub_insight(
        category="dynamic_shapes_analysis",
        insight="Dynamic shape recompilation detected",
        # Adjacent literals are joined at compile time, so the message text
        # does not depend on source indentation the way a backslash
        # continuation inside the quotes does.
        insight_description=(
            "PGO detected a recompilation due to dynamic shapes. "
            "Please follow the instruction from the action link to reduce "
            "recompilation overhead."
        ),
    )
    # add mlhub insight only once per rank
    _LOGGED_DYNAMIC_ALLOWLIST = True
|
||||
|
||||
|
||||
def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
|
||||
|
Reference in New Issue
Block a user