Reverting due to concerns over silent unsoundness (skipped hooks) if users have directly added hooks dicts without using official torch APIs.
This reverts commit 26045336ca323fd27cff2a7340fe896117d5fb6e.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96242
Approved by: https://github.com/albanD
This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks.
It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards. (But this observer change seems more involved...)
Idea: maintain a flag, and keep it up to date whenever adding or
removing hooks. Use the flag rather than dict checks to enter the call fast path.
- need to extend RemovableHandle to keep a ref to nnModule so it can update the flag on removal.
- also need to handle the flag in ScriptModule which still uses the python call impl when called from python.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95931
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
closes#35643
This PR is mostly borrowed from #82042. Thanks @Padarn for implementing
the first version and debugging into the errors.
Based on the discussion in #82042 this PR adds a with_kwargs
argument to register_forward_pre_hook and register_forward_hook
methods. When the arg is set to true, the provided hook must accept
kwargs args. Under the hook, this PR adds a
`_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs`
set to keep track of which hooks accept kwargs.
Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89389
Approved by: https://github.com/soulitzer