Add logging for when inline_inbuilt_nn_modules will help with ID_MATCH-guard-triggered recompiles (#160592)

We add logging for when an ID_MATCH guard is added at a place where inline_inbuilt_nn_modules would inline the module instead, with the aim of tagging recompiles that could be avoided by setting the inline_inbuilt_nn_modules flag.
This will help us track the flag's adoption and potentially quantify the savings in the number of recompiles.
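
A minimal repro sketch (not part of this diff) of the situation the new hint targets, assuming a PyTorch build that includes this change; the function and module below are illustrative. With TORCH_LOGS="recompiles", the second call is expected to report a recompile whose reason carries the new "[inline-inbuilt-nn-modules-candidate]" hint, and flipping the flag to True should avoid that recompile.

# Illustrative sketch, not taken from the PR's diff.
import torch

# With inlining off, Dynamo guards on nn.Module identity (ID_MATCH via NN_MODULE),
# so a second instance of the same module class triggers a recompile.
torch._dynamo.config.inline_inbuilt_nn_modules = False

@torch.compile(backend="eager")
def run(mod, x):
    return mod(x)

x = torch.randn(2, 2)
run(torch.nn.Linear(2, 2), x)  # first compile
run(torch.nn.Linear(2, 2), x)  # new module instance -> recompile; reason should carry the hint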

Differential Revision: D80075975

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160592
Approved by: https://github.com/anijain2305
Prajesh Praveen Anchalia
2025-08-15 17:09:39 +00:00
committed by PyTorch MergeBot
parent b26d2a9464
commit 052c441cf4
5 changed files with 108 additions and 10 deletions

View File

@@ -4,9 +4,58 @@ from unittest.mock import patch
 import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
+from torch._dynamo import config as dc


 class RecompileTests(torch._dynamo.test_case.TestCase):
+    def test_inline_inbuilt_nn_modules_candidate(self):
+        def hook_flag_on(guard_manager, f_locals, builder):
+            self.assertTrue(
+                "[inline-inbuilt-nn-modules-candidate]" not in str(guard_manager)
+            )
+
+        def hook_flag_off(guard_manager, f_locals, builder):
+            self.assertTrue(
+                "[inline-inbuilt-nn-modules-candidate]" in str(guard_manager)
+            )
+
+        class SubMod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            @torch.compile(backend="eager")
+            def forward(self, x):
+                return self.linear(x)
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.sm1 = SubMod()
+                self.sm2 = SubMod()
+
+            def forward(self, x):
+                return self.sm1(x) + self.sm2(x)
+
+        try:
+            from .utils import install_guard_manager_testing_hook
+        except ImportError:
+            from utils import install_guard_manager_testing_hook
+
+        with (
+            install_guard_manager_testing_hook(hook_flag_on),
+            dc.patch(inline_inbuilt_nn_modules=True),
+        ):
+            mod = Mod()
+            mod(torch.randn(2, 2))
+
+        with (
+            install_guard_manager_testing_hook(hook_flag_off),
+            dc.patch(inline_inbuilt_nn_modules=False),
+        ):
+            mod = Mod()
+            mod(torch.randn(2, 2))
+
     def test_automatic_dynamic_reduce_recompiles(self):
         # Test the counterfactual, lots of recompiles without this config
         def foo(x, y):

View File

@@ -500,6 +500,7 @@ class TestDynamoTimed(TestCase):
 'inductor_fx_remote_cache_hit_keys': None,
 'inductor_fx_remote_cache_miss_count': None,
 'inductor_fx_remote_cache_miss_keys': None,
+'inline_inbuilt_nn_modules_candidate': False,
 'is_forward': True,
 'is_runtime': False,
 'joint_graph_pass_time_us': 0,
@@ -583,6 +584,7 @@ class TestDynamoTimed(TestCase):
 'inductor_fx_remote_cache_hit_keys': None,
 'inductor_fx_remote_cache_miss_count': None,
 'inductor_fx_remote_cache_miss_keys': None,
+'inline_inbuilt_nn_modules_candidate': False,
 'is_forward': True,
 'is_runtime': False,
 'joint_graph_pass_time_us': 0,
@@ -677,6 +679,7 @@ class TestDynamoTimed(TestCase):
 'inductor_fx_remote_cache_hit_keys': None,
 'inductor_fx_remote_cache_miss_count': None,
 'inductor_fx_remote_cache_miss_keys': None,
+'inline_inbuilt_nn_modules_candidate': False,
 'is_forward': False,
 'is_runtime': False,
 'joint_graph_pass_time_us': None,
@@ -760,6 +763,7 @@ class TestDynamoTimed(TestCase):
 'inductor_fx_remote_cache_hit_keys': None,
 'inductor_fx_remote_cache_miss_count': None,
 'inductor_fx_remote_cache_miss_keys': None,
+'inline_inbuilt_nn_modules_candidate': False,
 'is_forward': False,
 'is_runtime': False,
 'joint_graph_pass_time_us': None,

View File

@@ -1080,7 +1080,31 @@ def _compile(
         recompile_reason = (
             "Unable to find recompilation reasons" if not reasons else reasons[0]
         )
-        metrics_context.update_outer({"recompile_reason": recompile_reason})
+
+        # Recheck for recompilation, for when inline_inbuilt_nn_modules is set to False
+        inline_inbuilt_nn_modules_candidate = False
+        if not config.inline_inbuilt_nn_modules and frame:
+            inbuilt_nn_reasons = get_and_maybe_log_recompilation_reasons(
+                cache_entry, frame, skip_logging=True
+            )
+            inbuilt_nn_recompile_reason = (
+                None if not inbuilt_nn_reasons else inbuilt_nn_reasons[0]
+            )
+            if (
+                inbuilt_nn_recompile_reason is not None
+                and "[inline-inbuilt-nn-modules-candidate]"
+                in inbuilt_nn_recompile_reason
+            ):
+                inline_inbuilt_nn_modules_candidate = True
+
+        # Set if the recompile is a candidate for inline_inbuilt_nn_modules
+        # regardless of whether inline_inbuilt_nn_modules is set or not
+        metrics_context.update_outer(
+            {
+                "recompile_reason": recompile_reason,
+                "inline_inbuilt_nn_modules_candidate": inline_inbuilt_nn_modules_candidate,
+            }
+        )

     recompile_user_contexts = get_hook_for_recompile_user_context()
     if recompile_user_contexts:

View File

@@ -767,11 +767,22 @@ def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str:
 def get_verbose_code_parts(
-    code_parts: Union[str, list[str]], guard: Optional[Guard]
+    code_parts: Union[str, list[str]],
+    guard: Optional[Guard],
+    recompile_hint: Optional[str] = None,
 ) -> list[str]:
     if not isinstance(code_parts, list):
         code_parts = [code_parts]
-    return [get_verbose_code_part(code_part, guard) for code_part in code_parts]
+    verbose_code_parts = [
+        get_verbose_code_part(code_part, guard) for code_part in code_parts
+    ]
+    if recompile_hint:
+        verbose_code_parts = [
+            f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts
+        ]
+    return verbose_code_parts


 def convert_int_to_concrete_values(dim: Any) -> Optional[int]:
@@ -1932,12 +1943,14 @@ class GuardBuilder(GuardBuilderBase):
             get_verbose_code_parts(code, guard)
         )

-    def ID_MATCH(self, guard: Guard) -> None:
+    def ID_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
         if self.serialization_mode == "save":
             raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.")
-        return self.id_match_unchecked(guard)
+        return self.id_match_unchecked(guard, recompile_hint)

-    def id_match_unchecked(self, guard: Guard) -> None:
+    def id_match_unchecked(
+        self, guard: Guard, recompile_hint: Optional[str] = None
+    ) -> None:
         # ___check_obj_id is same as `id(x) == y`
         if isinstance(guard.originating_source, TypeSource):
             # optional optimization to produce cleaner/faster guard code
@@ -1950,9 +1963,8 @@ class GuardBuilder(GuardBuilderBase):
         id_val = self.id_ref(val, guard.name)
         code = f"___check_obj_id({ref}, {id_val})"
         self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH")

         self.get_guard_manager(guard).add_id_match_guard(
-            id_val, get_verbose_code_parts(code, guard)
+            id_val, get_verbose_code_parts(code, guard, recompile_hint)
         )

         # Keep track of ID_MATCH'd objects. This will be used to modify the
@@ -2202,7 +2214,7 @@ class GuardBuilder(GuardBuilderBase):
             raise torch._dynamo.exc.PackageError(
                 "NN_MODULE guard cannot be serialized."
             )
-        self.ID_MATCH(guard)
+        self.ID_MATCH(guard, "[inline-inbuilt-nn-modules-candidate]")
         val = self.get(guard.name)
         if hasattr(val, "training"):
             assert istype(val.training, bool)
@@ -4031,10 +4043,13 @@ def get_guard_fail_reason(
     code: types.CodeType,
     f_locals: dict[str, object],
     compile_id: CompileId,
+    skip_logging: bool = False,
 ) -> str:
     if isinstance(guard_manager, DeletedGuardManagerWrapper):
         return f"{compile_id}: {guard_manager.invalidation_reason}"
     reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
+    if skip_logging:
+        return reason_str
     guard_failures[orig_code_map[code]].append(reason_str)

     try:
@@ -4051,7 +4066,9 @@

 def get_and_maybe_log_recompilation_reasons(
-    cache_entry: Optional[CacheEntry], frame: DynamoFrameType
+    cache_entry: Optional[CacheEntry],
+    frame: DynamoFrameType,
+    skip_logging: bool = False,
 ) -> list[str]:
     """
     Return the list of guard failure reasons using cache_entry.
@@ -4065,6 +4082,7 @@ def get_and_maybe_log_recompilation_reasons(
             cache_entry.code,
             frame.f_locals,
             cache_entry.compile_id,
+            skip_logging,
         )
         if reason:
             reasons.append(reason)
@@ -4072,6 +4090,8 @@ def get_and_maybe_log_recompilation_reasons(
     code = frame.f_code
+    if skip_logging:
+        return reasons

     # at least one of "recompiles" or "recompiles_verbose" is enabled
     do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()

View File

@@ -1362,6 +1362,7 @@ class CompilationMetrics:
     # the number of distinct type of params.
     param_count: Optional[int] = None
     recompile_user_contexts: Optional[set[str]] = None
+    inline_inbuilt_nn_modules_candidate: Optional[bool] = False

     @classmethod
     def create(cls, metrics: dict[str, Any]) -> CompilationMetrics: