Improve hooks ordering behavior (#85849)

Addresses: https://github.com/pytorch/pytorch/issues/35802

Design doc: https://docs.google.com/document/d/19xSib7FFknRQ5f3ptGFUmiOt3BrgXSUlTQH2xMcZJYg/edit#

### Changes in this PR

#### Implementation
- We now have 3 fields: `pre_hooks`, `retains_grad_hooks`, and `tensor_pre_hooks`, so that we can more precisely define their ordering and when they are executed (see the sketch below).
- Since retains_grad uses an entirely new field, we cannot reuse the old retains_grad logic. We refactor retains_grad to call directly into the variable.cpp logic. Other logic in variable.cpp that handles cpp hooks must also be updated.
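
As a rough orientation, here is how the three fields map onto user-facing entry points. This is a sketch based on my reading of the diff; the `Node.register_prehook` mapping in particular is an assumption:

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = a * 2  # b.grad_fn is a MulBackward0 Node

b.register_hook(lambda g: g)  # -> tensor_pre_hooks on b.grad_fn
b.retain_grad()               # -> retains_grad_hooks on b.grad_fn
b.grad_fn.register_prehook(lambda grad_outputs: None)  # -> pre_hooks (assumed)
```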

#### Hooks ordering and execution
- Pre-hooks registered on a tensor are now defined to run before pre-hooks registered on its grad_fn
- Pre-hooks registered on a tensor now always run, even if the tensor is passed as inputs= to .grad()
- Post-hooks (and grad_fn pre-hooks) can now observe the modifications to the gradient made by tensor pre-hooks; see the sketch below
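
A minimal sketch of the new ordering, assuming the `Node.register_prehook` / `Node.register_hook` APIs for the grad_fn-level hooks; the printed values follow from the semantics described above:

```python
import torch

order = []
a = torch.tensor(1.0, requires_grad=True)
b = a * 2

def tensor_pre(grad):
    order.append("tensor pre-hook")
    return grad * 10  # modify the gradient

def node_pre(grad_outputs):
    order.append(f"grad_fn pre-hook, sees {grad_outputs[0].item()}")

def node_post(grad_inputs, grad_outputs):
    order.append(f"post-hook, sees {grad_outputs[0].item()}")

b.register_hook(tensor_pre)           # tensor pre-hook: runs first
b.grad_fn.register_prehook(node_pre)  # grad_fn pre-hook: runs second
b.grad_fn.register_hook(node_post)    # post-hook: runs after the node

b.backward()
print(order)
# ['tensor pre-hook', 'grad_fn pre-hook, sees 10.0', 'post-hook, sees 10.0']
```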

#### Retains grad hooks
- Retains-grad hooks always execute last, even if other tensor pre-hooks are registered afterwards, so the captured .grad observes their modifications; see the sketch below
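
A sketch of the intended semantics, assuming the ordering rule above:

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = a * 2
b.retain_grad()                   # retains-grad hook registered first...
b.register_hook(lambda g: g * 3)  # ...but this tensor pre-hook still runs first

b.backward()
print(b.grad)  # tensor(3.): the retained grad reflects the hook's change
```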

#### Unchanged
- Pre-hooks registered on a grad_fn are still not expected to execute if the corresponding tensor is passed as inputs= to .grad() (see the sketch below)
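
A sketch contrasting the two kinds of pre-hooks when the tensor is a .grad() input, using the same assumed hook-registration APIs as above:

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = a * 2
out = b * 4

b.register_hook(lambda g: g * 10)  # tensor pre-hook: still runs (new behavior)
b.grad_fn.register_prehook(
    lambda grad_outputs: print("grad_fn pre-hook"))  # not expected to run

(gb,) = torch.autograd.grad(out, inputs=(b,))
print(gb)  # tensor(40.): tensor hook applied; "grad_fn pre-hook" never prints
```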

#### Follow ups
- simplify the retains_grad field to not be a vector, since it always holds a single hook
- potentially merge capture hooks with tensor pre-hooks; this would involve some additional refactoring
- the behavior of Python hooks registered to a tensor that is later modified in-place is still wrong

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85849
Approved by: https://github.com/albanD
Authored by soulitzer on 2023-01-11 12:33:29 -05:00; committed by PyTorch MergeBot
parent fb1427ea8f · commit 049838f249
15 changed files with 387 additions and 104 deletions

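The representative hunks below appear to be from torch/csrc/autograd/python_cpp_function.cpp (one of the 15 changed files; the filename is my inference from the function names). They show the GC traversal updated to visit all three hook fields, and tensor hooks routed to the new `add_tensor_pre_hook`: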

```diff
@@ -77,11 +77,25 @@ PyObject* THPCppFunction_call(
 int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
   auto& fn = *((THPCppFunction*)self)->cdata;
-  for (const auto& hook : fn.pre_hooks()) {
+  for (const auto& hook : fn.tensor_pre_hooks()) {
     if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
       Py_VISIT(pyhook->dict);
     }
   }
+  // NOTE [retains_grad_hook PyObject traversal]
+  // In theory this shouldn't be necessary, because retains_grad_hooks should
+  // not contain any PyFunctionTensorPreHooks. The alternative is to have a
+  // check that actually guarantees this.
+  for (const auto& hook : fn.retains_grad_hooks()) {
+    if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
+      Py_VISIT(pyhook->dict);
+    }
+  }
+  for (const auto& hook : fn.pre_hooks()) {
+    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
+      Py_VISIT(pyhook->dict);
+    }
+  }
   for (const auto& hook : fn.post_hooks()) {
     if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
       Py_VISIT(pyhook->dict);
@@ -152,7 +166,7 @@ PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) {
   auto& fn = *((THPCppFunction*)self)->cdata;
   std::unique_ptr<FunctionPreHook> hook(new PyFunctionTensorPreHook(
       var->backward_hooks, THPVariable_Unpack(var).output_nr()));
-  fn.add_pre_hook(std::move(hook));
+  fn.add_tensor_pre_hook(std::move(hook));
   Py_RETURN_NONE;
 }
```