Revert "Improve hooks ordering behavior (#85849)"

This reverts commit 049838f2496bd1d29e4e8292714acb0042cc706e.

Reverted https://github.com/pytorch/pytorch/pull/85849 on behalf of https://github.com/albanD due to fails internal build
This commit is contained in:
PyTorch MergeBot
2023-01-18 15:27:22 +00:00
parent 7f0d321d2e
commit e525f433e1
15 changed files with 104 additions and 387 deletions

View File

@ -77,22 +77,8 @@ PyObject* THPCppFunction_call(
int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
auto& fn = *((THPCppFunction*)self)->cdata;
for (const auto& hook : fn.tensor_pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
// NOTE [retains_grad_hook PyObject traversal]
// In theory this shouldn't be necessary, because retains_grad_hooks should
// not contain any PyFunctionTensorPreHooks. The alternative is to have a
// check that actually guarantees this.
for (const auto& hook : fn.retains_grad_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
for (const auto& hook : fn.pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
@ -167,7 +153,7 @@ PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) {
auto& fn = *((THPCppFunction*)self)->cdata;
std::unique_ptr<FunctionPreHook> hook(new PyFunctionTensorPreHook(
var->backward_hooks, THPVariable_Unpack(var).output_nr()));
fn.add_tensor_pre_hook(std::move(hook));
fn.add_pre_hook(std::move(hook));
Py_RETURN_NONE;
}