Revert "Improve hooks ordering behavior (#85849)"

This reverts commit 049838f2496bd1d29e4e8292714acb0042cc706e.

Reverted https://github.com/pytorch/pytorch/pull/85849 on behalf of https://github.com/albanD due to an internal build failure
PyTorch MergeBot
2023-01-18 15:27:22 +00:00
parent 7f0d321d2e
commit e525f433e1
15 changed files with 104 additions and 387 deletions


@@ -659,8 +659,6 @@ static int THPVariable_clear(THPVariable* self) {
     if (auto grad_acc =
             torch::autograd::impl::try_get_grad_accumulator(tensor)) {
       grad_acc->pre_hooks().clear();
-      grad_acc->tensor_pre_hooks().clear();
-      grad_acc->retains_grad_hooks().clear();
     }
   }
 }
@@ -1318,7 +1316,7 @@ int THPVariable_set_backwards_hooks(
   torch::autograd::impl::clear_hooks(tensor);
   if (obj) {
     torch::autograd::impl::add_hook(
-        tensor, std::make_unique<PyFunctionTensorPreHook>(obj, 0));
+        tensor, std::make_shared<PyFunctionTensorPreHook>(obj, 0));
   }
   return 0;
   END_HANDLE_TH_ERRORS_RET(-1)
@@ -2130,8 +2128,9 @@ static int THPVariable_subclass_traverse(
   // object, which requires the GIL to be accessed. Note that this is only
   // valid as long as user don't share non-owning references across
   // different threads (which is crazy and should never be done).
-  auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
   if (tensor.use_count() == 1) {
+    auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
     if (autograd_meta) {
       // Do NOT call grad_fn() here as that might trigger a recompute
       const auto& grad_fn = autograd_meta->grad_fn_;
@@ -2145,12 +2144,10 @@ static int THPVariable_subclass_traverse(
         }
       }
     }
-    if (autograd_meta) {
-      for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
-        if (auto pyhook =
-                dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
-          Py_VISIT(pyhook->dict);
-        }
+    for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
+      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
       }
     }
   }
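
For orientation, the `add_hook` / `PyFunctionTensorPreHook` machinery touched above backs `Tensor.register_hook` on the Python side (the `_backward_hooks` setter is `THPVariable_set_backwards_hooks`). A minimal sketch of that user-facing behavior, assuming standard PyTorch autograd semantics and not part of this commit:

```python
import torch

t = torch.ones(2, requires_grad=True)

# register_hook installs a tensor pre-hook; in C++ this reaches
# THPVariable_set_backwards_hooks -> torch::autograd::impl::add_hook,
# wrapping the Python hook dict in a PyFunctionTensorPreHook.
t.register_hook(lambda grad: grad * 2)

t.sum().backward()
print(t.grad)  # tensor([2., 2.]) -- the hook doubled the incoming gradient
```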