Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[dynamo] Emit warning on global module hooks when calling using output of torch.compile(module) (#152740)
When we do `torch.compile(module)`, we eventually end up returning a new `OptimizedModule` instance, whose `forward` method is the result of `torch.compile(mod.__call__)`, meaning it already captures all the extra logic (e.g., hook firing) for the compiled module. `OptimizedModule` also inherits `nn.Module.__call__`, and thus has its own hook logic. This is useful for torchao, which injects module forward hooks that run in eager mode for quantization purposes.

However, this can create unexpected behavior for global module hooks: `torch.compile(module)` causes each hook to fire one extra time for the `OptimizedModule` wrapper, compared to eager. To preserve backward compatibility, we simply emit a warning for this behavior and let users decide what to do. This is reasonable because the global module hooks are documented to be used for debugging/profiling purposes only.

Fixes #149502

Differential Revision: [D74611716](https://our.internmc.facebook.com/intern/diff/D74611716)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305, https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: b3dea0c0dd
Commit: 6765df052c
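To make the double firing concrete, here is a minimal sketch mirroring the new test added below; the `counter` bookkeeping and the expected counts are taken from that test, and the eager backend is used so no real compilation is needed:

```python
import torch

class Mod(torch.nn.Module):
    def forward(self, x):
        return x - 1

counter = 0

def hook(mod, args):
    # Global forward pre-hook: runs on every nn.Module.__call__.
    global counter
    counter += 1

handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
try:
    x = torch.rand(18, 18)
    mod = Mod()
    mod(x)  # eager: the hook fires once -> counter == 1
    compiled_mod = torch.compile(mod, backend="eager")
    compiled_mod(x)  # fires for the OptimizedModule wrapper *and* the inner
                     # module -> counter == 3, plus the new UserWarning
finally:
    handle.remove()
```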
test/dynamo/test_hooks.py
@@ -889,6 +889,36 @@ class HooksTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 1)

+    def test_global_module_forward_pre_hook(self):
+        class Mod(torch.nn.Module):
+            def forward(self, x):
+                return x - 1
+
+        counter = 0
+
+        def hook(mod, args):
+            nonlocal counter
+            counter += 1
+            return args
+
+        x = torch.rand(18, 18)
+        mod = Mod()
+        compiled_mod = torch.compile(mod, backend="eager")
+
+        try:
+            hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
+            ref = mod(x)
+            self.assertEqual(counter, 1)
+            with self.assertWarnsRegex(
+                UserWarning,
+                r"Using `torch.compile\(module\)` when there are global hooks.*",
+            ):
+                res = compiled_mod(x)
+            self.assertEqual(counter, 3)
+            self.assertEqual(ref, res)
+        finally:
+            hook_handle.remove()
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
torch/_dynamo/eval_frame.py
@@ -343,6 +343,19 @@ class OptimizedModule(torch.nn.Module):
             self._forward = self.forward
             self.forward = self._call_lazy_check

+    def __call__(self, *args, **kwargs):
+        if torch.nn.modules.module._has_any_global_hook():
+            warnings.warn(
+                "Using `torch.compile(module)` when there are global hooks on "
+                "modules (e.g., from `register_module_forward_hook`); this will"
+                " cause the hooks to fire an extra time for the "
+                "`OptimizedModule` created by `torch.compile(module)`. If this "
+                "causes undesired behavior, please try using `module.compile()`"
+                ", or use the per-module hooks instead",
+                stacklevel=2,
+            )
+        return super().__call__(*args, **kwargs)
+
     def __reduce__(self):
         return (self.__class__, (self._orig_mod, self.dynamo_ctx))
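As the warning text suggests, `module.compile()` compiles the module in place instead of wrapping it in a new `OptimizedModule`, so a global hook should fire only once per call, as in eager. A rough sketch of that alternative (assuming the in-place behavior of `nn.Module.compile`; the hook and counter names are illustrative):

```python
import torch

class Mod(torch.nn.Module):
    def forward(self, x):
        return x - 1

calls = 0

def hook(mod, args):
    global calls
    calls += 1

handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
try:
    mod = Mod()
    mod.compile(backend="eager")  # compiles in place; no wrapper module is created
    mod(torch.rand(4, 4))         # the global hook fires once per call, as in eager
finally:
    handle.remove()
```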
torch/nn/modules/module.py
@@ -118,6 +118,18 @@ _global_forward_hooks: dict[int, Callable] = OrderedDict()
 _global_forward_hooks_always_called: dict[int, bool] = OrderedDict()
 _global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict()

+
+def _has_any_global_hook():
+    return (
+        _global_backward_pre_hooks
+        or _global_backward_hooks
+        or _global_forward_pre_hooks
+        or _global_forward_hooks
+        or _global_forward_hooks_always_called
+        or _global_forward_hooks_with_kwargs
+    )
+
+
 _EXTRA_STATE_KEY_SUFFIX = "_extra_state"