mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add logging for when inline_inbuilt_nn_modules will help with ID_MATCH guard triggered recompiles (#160592)
We add logging for when an ID_MATCH guard is added at a place where inline_inbuilt_nn_modules would inline it. This is done with the aim of tagging recompiles that could be avoided by setting the inline_inbuilt_nn_modules flag. It will help us log and track the flag's adoption and potentially quantify the savings in the number of recompiles. Differential Revision: D80075975 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160592 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
b26d2a9464
commit
052c441cf4
@ -4,9 +4,58 @@ from unittest.mock import patch
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo import config as dc
|
||||
|
||||
|
||||
class RecompileTests(torch._dynamo.test_case.TestCase):
|
||||
def test_inline_inbuilt_nn_modules_candidate(self):
    """Check that the candidate tag shows up in guards only when the flag is off.

    A small module tree is compiled twice, once with
    ``inline_inbuilt_nn_modules=True`` and once with it set to ``False``.
    A guard-manager testing hook inspects the string form of the guard
    manager and asserts that ``[inline-inbuilt-nn-modules-candidate]`` is
    absent in the first run and present in the second.
    """
    candidate_tag = "[inline-inbuilt-nn-modules-candidate]"

    # assertIn/assertNotIn report both operands on failure, unlike a bare
    # assertTrue on a membership expression.
    def hook_flag_on(guard_manager, f_locals, builder):
        self.assertNotIn(candidate_tag, str(guard_manager))

    def hook_flag_off(guard_manager, f_locals, builder):
        self.assertIn(candidate_tag, str(guard_manager))

    class SubMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

        @torch.compile(backend="eager")
        def forward(self, x):
            return self.linear(x)

    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sm1 = SubMod()
            self.sm2 = SubMod()

        def forward(self, x):
            return self.sm1(x) + self.sm2(x)

    # The helper lives next to this test file; fall back to a plain import
    # when the relative import is unavailable.
    try:
        from .utils import install_guard_manager_testing_hook
    except ImportError:
        from utils import install_guard_manager_testing_hook

    # Flag on: the tag must not appear in the guards.
    with (
        install_guard_manager_testing_hook(hook_flag_on),
        dc.patch(inline_inbuilt_nn_modules=True),
    ):
        mod = Mod()
        mod(torch.randn(2, 2))

    # Flag off: the tag must appear in the guards.
    with (
        install_guard_manager_testing_hook(hook_flag_off),
        dc.patch(inline_inbuilt_nn_modules=False),
    ):
        mod = Mod()
        mod(torch.randn(2, 2))
|
||||
|
||||
def test_automatic_dynamic_reduce_recompiles(self):
|
||||
# Test the counterfactual, lots of recompiles without this config
|
||||
def foo(x, y):
|
||||
|
@ -500,6 +500,7 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_fx_remote_cache_hit_keys': None,
|
||||
'inductor_fx_remote_cache_miss_count': None,
|
||||
'inductor_fx_remote_cache_miss_keys': None,
|
||||
'inline_inbuilt_nn_modules_candidate': False,
|
||||
'is_forward': True,
|
||||
'is_runtime': False,
|
||||
'joint_graph_pass_time_us': 0,
|
||||
@ -583,6 +584,7 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_fx_remote_cache_hit_keys': None,
|
||||
'inductor_fx_remote_cache_miss_count': None,
|
||||
'inductor_fx_remote_cache_miss_keys': None,
|
||||
'inline_inbuilt_nn_modules_candidate': False,
|
||||
'is_forward': True,
|
||||
'is_runtime': False,
|
||||
'joint_graph_pass_time_us': 0,
|
||||
@ -677,6 +679,7 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_fx_remote_cache_hit_keys': None,
|
||||
'inductor_fx_remote_cache_miss_count': None,
|
||||
'inductor_fx_remote_cache_miss_keys': None,
|
||||
'inline_inbuilt_nn_modules_candidate': False,
|
||||
'is_forward': False,
|
||||
'is_runtime': False,
|
||||
'joint_graph_pass_time_us': None,
|
||||
@ -760,6 +763,7 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_fx_remote_cache_hit_keys': None,
|
||||
'inductor_fx_remote_cache_miss_count': None,
|
||||
'inductor_fx_remote_cache_miss_keys': None,
|
||||
'inline_inbuilt_nn_modules_candidate': False,
|
||||
'is_forward': False,
|
||||
'is_runtime': False,
|
||||
'joint_graph_pass_time_us': None,
|
||||
|
@ -1080,7 +1080,31 @@ def _compile(
|
||||
recompile_reason = (
|
||||
"Unable to find recompilation reasons" if not reasons else reasons[0]
|
||||
)
|
||||
metrics_context.update_outer({"recompile_reason": recompile_reason})
|
||||
# Recheck for recompilation, for when inline_inbuilt_nn_modules is set to False
|
||||
inline_inbuilt_nn_modules_candidate = False
|
||||
if not config.inline_inbuilt_nn_modules and frame:
|
||||
inbuilt_nn_reasons = get_and_maybe_log_recompilation_reasons(
|
||||
cache_entry, frame, skip_logging=True
|
||||
)
|
||||
inbuilt_nn_recompile_reason = (
|
||||
None if not inbuilt_nn_reasons else inbuilt_nn_reasons[0]
|
||||
)
|
||||
|
||||
if (
|
||||
inbuilt_nn_recompile_reason is not None
|
||||
and "[inline-inbuilt-nn-modules-candidate]"
|
||||
in inbuilt_nn_recompile_reason
|
||||
):
|
||||
inline_inbuilt_nn_modules_candidate = True
|
||||
|
||||
# Set if the recompile is a candidate for inline_inbuilt_nn_modules
|
||||
# regardless of whether inline_inbuilt_nn_modules is set or not
|
||||
metrics_context.update_outer(
|
||||
{
|
||||
"recompile_reason": recompile_reason,
|
||||
"inline_inbuilt_nn_modules_candidate": inline_inbuilt_nn_modules_candidate,
|
||||
}
|
||||
)
|
||||
|
||||
recompile_user_contexts = get_hook_for_recompile_user_context()
|
||||
if recompile_user_contexts:
|
||||
|
@ -767,11 +767,22 @@ def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str:
|
||||
|
||||
|
||||
def get_verbose_code_parts(
    code_parts: Union[str, list[str]],
    guard: Optional[Guard],
    recompile_hint: Optional[str] = None,
) -> list[str]:
    """Return the verbose form of each guard code part.

    Args:
        code_parts: A single code string or a list of code strings. A
            non-list value is wrapped into a one-element list.
        guard: Passed through to ``get_verbose_code_part`` for every part.
        recompile_hint: Optional text. When truthy, ``" (HINT: <hint>)"``
            is appended to every verbose part.

    Returns:
        One verbose string per input code part, in input order.
    """
    if not isinstance(code_parts, list):
        code_parts = [code_parts]

    verbose_code_parts = [
        get_verbose_code_part(code_part, guard) for code_part in code_parts
    ]
    # An empty or missing hint leaves the verbose parts untouched.
    if not recompile_hint:
        return verbose_code_parts

    return [f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts]
|
||||
|
||||
|
||||
def convert_int_to_concrete_values(dim: Any) -> Optional[int]:
|
||||
@ -1932,12 +1943,14 @@ class GuardBuilder(GuardBuilderBase):
|
||||
get_verbose_code_parts(code, guard)
|
||||
)
|
||||
|
||||
def ID_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
    """Install an ID_MATCH guard, refusing to do so while saving.

    Args:
        guard: The guard to install.
        recompile_hint: Optional text forwarded to ``id_match_unchecked``.

    Raises:
        torch._dynamo.exc.PackageError: If ``self.serialization_mode`` is
            ``"save"``.
    """
    if self.serialization_mode == "save":
        raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.")
    # Forward the hint; omitting it would silently drop the tag for callers
    # that pass one.
    return self.id_match_unchecked(guard, recompile_hint)
|
||||
|
||||
def id_match_unchecked(self, guard: Guard) -> None:
|
||||
def id_match_unchecked(
|
||||
self, guard: Guard, recompile_hint: Optional[str] = None
|
||||
) -> None:
|
||||
# ___check_obj_id is same as `id(x) == y`
|
||||
if isinstance(guard.originating_source, TypeSource):
|
||||
# optional optimization to produce cleaner/faster guard code
|
||||
@ -1950,9 +1963,8 @@ class GuardBuilder(GuardBuilderBase):
|
||||
id_val = self.id_ref(val, guard.name)
|
||||
code = f"___check_obj_id({ref}, {id_val})"
|
||||
self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH")
|
||||
|
||||
self.get_guard_manager(guard).add_id_match_guard(
|
||||
id_val, get_verbose_code_parts(code, guard)
|
||||
id_val, get_verbose_code_parts(code, guard, recompile_hint)
|
||||
)
|
||||
|
||||
# Keep track of ID_MATCH'd objects. This will be used to modify the
|
||||
@ -2202,7 +2214,7 @@ class GuardBuilder(GuardBuilderBase):
|
||||
raise torch._dynamo.exc.PackageError(
|
||||
"NN_MODULE guard cannot be serialized."
|
||||
)
|
||||
self.ID_MATCH(guard)
|
||||
self.ID_MATCH(guard, "[inline-inbuilt-nn-modules-candidate]")
|
||||
val = self.get(guard.name)
|
||||
if hasattr(val, "training"):
|
||||
assert istype(val.training, bool)
|
||||
@ -4031,10 +4043,13 @@ def get_guard_fail_reason(
|
||||
code: types.CodeType,
|
||||
f_locals: dict[str, object],
|
||||
compile_id: CompileId,
|
||||
skip_logging: bool = False,
|
||||
) -> str:
|
||||
if isinstance(guard_manager, DeletedGuardManagerWrapper):
|
||||
return f"{compile_id}: {guard_manager.invalidation_reason}"
|
||||
reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
|
||||
if skip_logging:
|
||||
return reason_str
|
||||
guard_failures[orig_code_map[code]].append(reason_str)
|
||||
|
||||
try:
|
||||
@ -4051,7 +4066,9 @@ def get_guard_fail_reason(
|
||||
|
||||
|
||||
def get_and_maybe_log_recompilation_reasons(
|
||||
cache_entry: Optional[CacheEntry], frame: DynamoFrameType
|
||||
cache_entry: Optional[CacheEntry],
|
||||
frame: DynamoFrameType,
|
||||
skip_logging: bool = False,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Return the list of guard failure reasons using cache_entry.
|
||||
@ -4065,6 +4082,7 @@ def get_and_maybe_log_recompilation_reasons(
|
||||
cache_entry.code,
|
||||
frame.f_locals,
|
||||
cache_entry.compile_id,
|
||||
skip_logging,
|
||||
)
|
||||
if reason:
|
||||
reasons.append(reason)
|
||||
@ -4072,6 +4090,8 @@ def get_and_maybe_log_recompilation_reasons(
|
||||
|
||||
code = frame.f_code
|
||||
|
||||
if skip_logging:
|
||||
return reasons
|
||||
# at least one of "recompiles" or "recompiles_verbose" is enabled
|
||||
do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()
|
||||
|
||||
|
@ -1362,6 +1362,7 @@ class CompilationMetrics:
|
||||
# the number of distinct type of params.
|
||||
param_count: Optional[int] = None
|
||||
recompile_user_contexts: Optional[set[str]] = None
|
||||
inline_inbuilt_nn_modules_candidate: Optional[bool] = False
|
||||
|
||||
@classmethod
|
||||
def create(cls, metrics: dict[str, Any]) -> CompilationMetrics:
|
||||
|
Reference in New Issue
Block a user