[dynamo] Emit warning on global module hooks when calling using output of torch.compile(module) (#152740)

When we do `torch.compile(module)`, we eventually end up returning a new
`OptimizedModule` instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) for the compiled module.

`OptimizedModule` also inherits `nn.module.__call__`, and thus
has its own hook logic. This is useful for torchao, which injects module
forward hooks to run in eager for quantization purposes.

However, this might create unexpected behavior for global module hooks,
because `torch.compile(module)` causes the hook to fire one extra time
for `OptimizedModule`, when compared to eager.

To preserve BC, we simply emit a warning for this behavior, and let
users decide what to do. This is reasonable because the global module
hooks are documented to be used for debugging/profiling purposes only.

Fixes #149502

Differential Revision: [D74611716](https://our.internmc.facebook.com/intern/diff/D74611716)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305, https://github.com/zou3519
This commit is contained in:
Ryan Guo
2025-05-13 12:55:42 -07:00
committed by PyTorch MergeBot
parent b3dea0c0dd
commit 6765df052c
3 changed files with 55 additions and 0 deletions

View File

@ -889,6 +889,36 @@ class HooksTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
def test_global_module_forward_pre_hook(self):
class Mod(torch.nn.Module):
def forward(self, x):
return x - 1
counter = 0
def hook(mod, args):
nonlocal counter
counter += 1
return args
x = torch.rand(18, 18)
mod = Mod()
compiled_mod = torch.compile(mod, backend="eager")
try:
hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
ref = mod(x)
self.assertEqual(counter, 1)
with self.assertWarnsRegex(
UserWarning,
r"Using `torch.compile\(module\)` when there are global hooks.*",
):
res = compiled_mod(x)
self.assertEqual(counter, 3)
self.assertEqual(ref, res)
finally:
hook_handle.remove()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -343,6 +343,19 @@ class OptimizedModule(torch.nn.Module):
self._forward = self.forward
self.forward = self._call_lazy_check
def __call__(self, *args, **kwargs):
if torch.nn.modules.module._has_any_global_hook():
warnings.warn(
"Using `torch.compile(module)` when there are global hooks on "
"modules (e.g., from `register_module_forward_hook`); this will"
" cause the hooks to fire an extra time for the "
"`OptimizedModule` created by `torch.compile(module)`. If this "
"causes undesired behavior, please try using `module.compile()`"
", or use the per-module hooks instead",
stacklevel=2,
)
return super().__call__(*args, **kwargs)
def __reduce__(self):
return (self.__class__, (self._orig_mod, self.dynamo_ctx))

View File

@ -118,6 +118,18 @@ _global_forward_hooks: dict[int, Callable] = OrderedDict()
_global_forward_hooks_always_called: dict[int, bool] = OrderedDict()
_global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict()
def _has_any_global_hook():
return (
_global_backward_pre_hooks
or _global_backward_hooks
or _global_forward_pre_hooks
or _global_forward_hooks
or _global_forward_hooks_always_called
or _global_forward_hooks_with_kwargs
)
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"