Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
This is an extension of [ModuleTracker](https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py) with added features and bug fixes:

1. Allows installing user-defined hooks to be called in the pre-forward (pre-fw), post-forward (post-fw), pre-backward (pre-bw) and post-backward (post-bw) hooks of the ``ModTracker``.
2. Adds a function ``get_known_fqn`` that retrieves the fully qualified name (fqn) of a module as tracked by the ``ModTracker``.
3. Only registers the multi-grad hooks during the forward pass. This matters because a module's pre-fw and post-fw hooks are also called during the backward pass when activation checkpointing (AC) recomputes the forward, and we do not want to register multi-grad hooks in that case.
4. Sets the kwarg ``always_call=True`` for post-fw hooks, so that they are still called under AC (a forward hook registered with ``always_call=True`` runs even if the module's forward raises an exception).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128508
Approved by: https://github.com/wanchaol
Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than C headers. These headers serve double duty as *internal implementation detail* headers, whose contents should largely not be used by external clients.

Ideally, we would not install these headers at all; instead, you should use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`) to manipulate these structs. However, there are a few places in torch/csrc where we violate this abstraction. They are marked with a pointer to this note. Each of those sites will have to be refactored when we refactor the guts of THTensor and related structures.