diff --git a/test/test_autograd.py b/test/test_autograd.py
index f42ec69944a9..001fb2302a5d 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -889,166 +889,6 @@ class TestAutograd(TestCase):
         with self.assertRaisesRegex(ValueError, "Expected allow_unused to be True or not passed when"):
             torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)

-    def test_post_accumulate_grad_hook_on_non_leaf(self):
-        def hook(tensor):
-            tensor.sub_(1.)
-        leaf = torch.rand(3, requires_grad=True)
-        non_leaf = 2. * leaf
-
-        with self.assertRaisesRegex(
-                RuntimeError,
-                "post accumulate grad hooks cannot be registered on non-leaf tensors"):
-            non_leaf.register_post_accumulate_grad_hook(hook)
-
-    def test_post_accumulate_grad_hook_multiple_hooks(self):
-        def hook1(tensor):
-            tensor.sub_(tensor.grad)
-
-        def hook2(tensor):
-            tensor.mul_(4.)
-        tensor = torch.rand(3, requires_grad=True)
-        tensor_ref = tensor.clone().detach()
-        tensor.register_post_accumulate_grad_hook(hook1)
-        tensor.register_post_accumulate_grad_hook(hook2)
-        sum = tensor.sum()
-        sum.backward()
-        # both hooks should be called, in order
-        self.assertEqual(4. * (tensor_ref - 1.), tensor)
-
-    def test_post_accumulate_grad_hook_multiple_tensors(self):
-        def hook(tensor):
-            tensor.sub_(tensor.grad)
-        tensor1 = torch.rand(3, requires_grad=True)
-        tensor1_ref = tensor1.clone().detach()
-        tensor2 = torch.rand(5, requires_grad=True)
-        tensor2_ref = tensor2.clone().detach()
-        tensor1.register_post_accumulate_grad_hook(hook)
-        tensor2.register_post_accumulate_grad_hook(hook)
-        tensor1.sum().backward()
-        tensor2.sum().backward()
-        # both tensors should have been modified
-        self.assertEqual(tensor1_ref - 1., tensor1)
-        self.assertEqual(tensor2_ref - 1., tensor2)
-
-    def test_post_accumulate_grad_hook_returns_not_None(self):
-        def bad_hook(tensor):
-            return tensor.grad
-        tensor = torch.rand(2, 3, requires_grad=True)
-        tensor.register_post_accumulate_grad_hook(bad_hook)
-        # should error!
-        with self.assertRaisesRegex(RuntimeError, "hooks should return None."):
-            tensor.sum().backward()
-
-    def test_post_accumulate_grad_hook_e2e(self):
-        def setup_optim_in_bwd(model):
-            optims = {}
-            handles = []
-
-            def optim_step_hook(param):
-                optims[param].step()
-                optims[param].zero_grad()
-
-            for p in model.parameters():
-                optims[p] = torch.optim.Adam([p])
-                handles.append(p.register_post_accumulate_grad_hook(optim_step_hook))
-
-            return handles
-
-        model = torch.nn.Linear(3, 2)
-        input = torch.rand(2, 3)
-        handles = setup_optim_in_bwd(model)
-
-        # make a copy for reference
-        model_copy = deepcopy(model)
-        optim_copy = torch.optim.Adam(model_copy.parameters())
-
-        iters = 5
-
-        for _ in range(iters):
-            loss = model(input).sum()
-            loss.backward()
-
-            loss_copy = model_copy(input).sum()
-            loss_copy.backward()
-            optim_copy.step()
-            optim_copy.zero_grad()
-
-        params_copy = []  # freeze a copy of the params to compare later
-        for p_reference, p in zip(model_copy.parameters(), model.parameters()):
-            self.assertEqual(p_reference, p)
-            params_copy.append(p_reference.clone().detach())
-
-        # After removing the handle, the model should no longer update.
-        for h in handles:
-            h.remove()
-
-        for _ in range(iters):
-            loss = model(input).sum()
-            loss.backward()
-
-            loss_copy = model_copy(input).sum()
-            loss_copy.backward()
-            optim_copy.step()
-            optim_copy.zero_grad()
-
-        for p_static, p_reference, p in zip(params_copy, model_copy.parameters(), model.parameters()):
-            self.assertEqual(p_static, p)
-            self.assertNotEqual(p_reference, p)
-
-    def test_post_accumulate_grad_hook_gets_cleaned_up(self):
-
-        def fun_stuff_with_hook():
-            thing_to_put_in_hook = torch.rand(3)
-
-            def hook(tensor):
-                tensor.sub_(tensor.grad)
-                tensor.add_(thing_to_put_in_hook)
-
-            tensor = torch.rand(3, requires_grad=True)
-            tensor.register_post_accumulate_grad_hook(hook)
-            tensor.sum().backward()
-            ref = weakref.ref(thing_to_put_in_hook)
-            gc.collect()
-            return tensor, ref
-
-        with disable_gc():
-            tensor, ref = fun_stuff_with_hook()
-            self.assertIsNotNone(ref())  # thing_to_put_in_hook should be kept alive by tensor
-
-            del tensor
-            gc.collect()
-            self.assertIsNone(ref())  # thing_to_put_in_hook should be cleaned
-
-    def test_post_accumulate_grad_hook_ordering(self):
-        tensor = torch.rand(3, requires_grad=True)
-
-        def pre_hook(grad):
-            return grad.sub(2.)
-
-        def acc_grad_node_pre_hook(grad_out):
-            return (grad_out[0].div(5.),)
-
-        def post_acc_grad_hook(tensor):
-            tensor.grad.add_(0.5)
-
-        def acc_grad_node_post_hook(grad_in, grad_out):
-            tensor.grad = grad_out[0].mul(10)
-
-        acc_grad = tensor.view_as(tensor).grad_fn.next_functions[0][0]
-        tensor.register_hook(pre_hook)
-        acc_grad.register_prehook(acc_grad_node_pre_hook)
-        tensor.register_post_accumulate_grad_hook(post_acc_grad_hook)
-        acc_grad.register_hook(acc_grad_node_post_hook)
-        tensor.sum().backward()
-
-        # the hooks should run in the order of:
-        #   1. tensor prehook
-        #   2. acc_grad prehook
-        #   3. tensor post acc_grad hook
-        #   4. acc_grad posthook
-        # so that would be ((1 - 2) / 5 + 0.5) * 10 = 3
-        self.assertEqual(torch.tensor([3., 3., 3.]), tensor.grad)
-
     def test_hook_with_no_name(self):
         # Create a hook that do not have a __name__ attribute
         class MyHookClass:
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 8d135b4dcb87..a5ebd04a1354 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -530,7 +530,7 @@ class Tensor(torch._C._TensorBase):
             return handle_torch_function(Tensor.register_hook, (self,), self, hook)
         if not self.requires_grad:
             raise RuntimeError(
-                "cannot register a hook on a tensor that doesn't require gradient"
+                "cannot register a hook on a tensor that " "doesn't require gradient"
             )
         if self._backward_hooks is None:
             self._backward_hooks = OrderedDict()
@@ -540,62 +540,6 @@ class Tensor(torch._C._TensorBase):
         self._backward_hooks[handle.id] = hook
         return handle

-    def register_post_accumulate_grad_hook(self, hook):
-        r"""Registers a backward hook that runs after grad accumulation.
-
-        The hook will be called after all gradients for a tensor have been accumulated,
-        meaning that the .grad field has been updated on that tensor. The post
-        accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
-        .grad_fn field). Registering this hook on a non-leaf tensor will error!
-
-        The hook should have the following signature::
-
-            hook(param: Tensor) -> None
-
-        Note that, unlike other autograd hooks, this hook operates on the tensor
-        that requires grad and not the grad itself. The hook can in-place modify
-        and access its Tensor argument, including its .grad field.
-
-        This function returns a handle with a method ``handle.remove()``
-        that removes the hook from the module.
-
-        .. note::
-            See :ref:`backward-hooks-execution` for more information on how when this hook
-            is executed, and how its execution is ordered relative to other hooks. Since
-            this hook runs during the backward pass, it will run in no_grad mode (unless
-            create_graph is True). You can use torch.enable_grad() to re-enable autograd
-            within the hook if you need it.
-
-        Example::
-
-            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
-            >>> lr = 0.01
-            >>> # simulate a simple SGD update
-            >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
-            >>> v.backward(torch.tensor([1., 2., 3.]))
-            >>> v
-            tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
-
-            >>> h.remove()  # removes the hook
-        """
-        if has_torch_function_unary(self):
-            return handle_torch_function(
-                Tensor.register_post_accumulate_grad_hook, (self,), self, hook
-            )
-        if not self.requires_grad:
-            raise RuntimeError(
-                "cannot register a hook on a tensor that doesn't require gradient"
-            )
-        if self.grad_fn is not None:
-            raise RuntimeError(
-                "post accumulate grad hooks cannot be registered on non-leaf tensors"
-            )
-        if self._post_accumulate_grad_hooks is None:
-            self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
-        handle = hooks.RemovableHandle(self._post_accumulate_grad_hooks)
-        self._post_accumulate_grad_hooks[handle.id] = hook
-        return handle
-
     def reinforce(self, reward):
         def trim(str):
             return "\n".join([line.strip() for line in str.split("\n")])
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 9b0c3e3fd27e..d8066947f652 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -523,12 +523,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     return tensor_pre_hooks_;
   }

-  virtual std::unique_ptr<PostAccumulateGradHook>&
-  tensor_post_acc_grad_hooks() noexcept {
-    static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
-    return empty;
-  }
-
   std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>>&
   retains_grad_hooks() noexcept {
     return retains_grad_hooks_;
diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h
index 21c9868940af..477c6ec1bac2 100644
--- a/torch/csrc/autograd/function_hook.h
+++ b/torch/csrc/autograd/function_hook.h
@@ -42,17 +42,5 @@ struct TORCH_API FunctionPostHook {
   }
 };

-struct TORCH_API PostAccumulateGradHook {
-  virtual ~PostAccumulateGradHook() = default;
-  virtual void operator()(const Variable& tensor) = 0;
-  // only implemented for python hooks on nodes, registers hook with compiled
-  // autograd
-  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
-    throw std::runtime_error(
-        std::string("not yet implemented for compiled autograd: ") +
-        typeid(*this).name());
-  }
-};
-
 } // namespace autograd
 } // namespace torch
diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp
index c7f0923752c9..5911e04d9f07 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.cpp
+++ b/torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -57,11 +57,6 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
       1 + !post_hooks().empty() /* num_expected_refs */,
       [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });

-  auto& hook = tensor_post_acc_grad_hooks();
-  if (hook != nullptr) {
-    (*hook)(variable);
-  }
-
   return variable_list();
 }

@@ -93,7 +88,6 @@ variable_list AccumulateGrad::apply_with_saved(
       });
   saved.after(variable_copy);
   saved.after(grad_copy);
-
   return variable_list();
 }

diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h
index 2efde9d5f2f2..49d8b2ecb4f4 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.h
+++ b/torch/csrc/autograd/functions/accumulate_grad.h
@@ -50,15 +50,6 @@ struct TORCH_API AccumulateGrad : public Node {
     // to all other Nodes). So we must lazily read the Tensor hooks here.
     return impl::hooks(variable);
   }
-
-  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
-      override {
-    // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
-    // it can be destroyed even though the Tensor is still alive (contrary
-    // to all other Nodes). So we must lazily read the Tensor hooks here.
-    return impl::post_acc_grad_hooks(variable);
-  }
-
   // Given a variable with its current grad as variable_grad, accumulates
   // new_grad into variable_grad if in place accumulation is possible.
   // Otherwise, uses 'update_grad' to update the grad for the variable.
diff --git a/torch/csrc/autograd/python_hook.cpp b/torch/csrc/autograd/python_hook.cpp
index 25cf3b38a711..ecc2939e5963 100644
--- a/torch/csrc/autograd/python_hook.cpp
+++ b/torch/csrc/autograd/python_hook.cpp
@@ -30,25 +30,22 @@ namespace autograd {

 namespace {

-// This function is called in 4 different cases:
+// This function is called in 3 different cases:
 //   1) TensorPreHook
 //   2) PreHook
 //   3) PostHook
-//   4) TensorPostAccGradHook
 //
 // Depending on the case, args and res can hold different types of objects:
 //
 // args:
-// TensorPreHook          (Tensor,)
-// PreHook                ((Tensor, ...),)                (grad_outputs,)
-// PostHook               ((Tensor, ...), (Tensor, ...))  (grad_inputs, grad_outputs)
-// TensorPostAccGradHook  ((Tensor), ())                  (tensor,)
+// TensorPreHook   (Tensor,)
+// PreHook         ((Tensor, ...),)                (grad_outputs,)
+// PostHook        ((Tensor, ...), (Tensor, ...))  (grad_inputs, grad_outputs)
 //
 // res:
-// TensorPreHook          Tensor
-// PreHook                ((Tensor, ...),)  (grad_outputs,)
-// PostHook               ((Tensor, ...),)  (grad_inputs,)
-// TensorPostAccGradHook  None
+// TensorPreHook   Tensor
+// PreHook         ((Tensor, ...),)  (grad_outputs,)
+// PostHook        ((Tensor, ...),)  (grad_inputs,)
 //
 // This function returns True if any hook returned non-None value, and False
 // otherwise.
@@ -196,30 +193,6 @@ void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
   }
 }

-PyFunctionTensorPostAccGradHooks::PyFunctionTensorPostAccGradHooks(
-    PyObject* dict)
-    : dict(dict) {
-  Py_INCREF(dict);
-}
-
-PyFunctionTensorPostAccGradHooks::~PyFunctionTensorPostAccGradHooks() {
-  // If python is already dead, leak the wrapped python objects
-  if (Py_IsInitialized()) {
-    pybind11::gil_scoped_acquire gil;
-    Py_DECREF(dict);
-  }
-}
-
-auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
-    -> void {
-  pybind11::gil_scoped_acquire gil;
-  THPObjectPtr tup(PyTuple_New(1));
-  PyTuple_SET_ITEM(tup.get(), 0, THPVariable_Wrap(tensor));
-  bool returned_none = !_call_hooks(dict, tup.get());
-  TORCH_CHECK(
-      returned_none, "Tensor post accumulate grad hooks should return None.");
-}
-
 } // namespace autograd
 } // namespace torch
diff --git a/torch/csrc/autograd/python_hook.h b/torch/csrc/autograd/python_hook.h
index 4c1d933f00d9..d8aa3ebf3bc3 100644
--- a/torch/csrc/autograd/python_hook.h
+++ b/torch/csrc/autograd/python_hook.h
@@ -34,17 +34,5 @@ struct PyFunctionPostHook : public FunctionPostHook {
   PyObject* dict;
 };

-// PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks,
-// and it is understandable if you are confused by why it's a subclass. We are
-// simply following the precedent of PyFunctionPreHook and PyFunctionPostHook
-// above to easily enroll into existing infrastructure.
-struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
-  PyFunctionTensorPostAccGradHooks(PyObject* dict);
-  ~PyFunctionTensorPostAccGradHooks() override;
-  void operator()(const Variable& tensor) override;
-  // fall back to the compiled_args of PostAccumulateGradHook superclass
-  PyObject* dict;
-};
-
 } // namespace autograd
 } // namespace torch
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 9946e2d17bed..aad0fdd599d0 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -428,7 +428,6 @@ static int THPVariable_clear(THPVariable* self) {
     return 0;
   }
   Py_CLEAR(self->backward_hooks);
-  Py_CLEAR(self->post_accumulate_grad_hooks);
   const auto& tensor = THPVariable_Unpack(self);
   if (tensor.defined()) {
     // Two situations to consider:
@@ -1166,47 +1165,6 @@ int THPVariable_set_backwards_hooks(
   END_HANDLE_TH_ERRORS_RET(-1)
 }

-PyObject* THPVariable_get_post_accumulate_grad_hooks(
-    THPVariable* self,
-    void* unused) {
-  HANDLE_TH_ERRORS
-  if (check_has_torch_function((PyObject*)self)) {
-    return handle_torch_function_getter(self, "_post_accumulate_grad_hooks");
-  }
-  if (self->post_accumulate_grad_hooks) {
-    Py_INCREF(self->post_accumulate_grad_hooks);
-    return self->post_accumulate_grad_hooks;
-  }
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-int THPVariable_set_post_accumulate_grad_hooks(
-    THPVariable* self,
-    PyObject* obj,
-    void* unused) {
-  HANDLE_TH_ERRORS
-  if (check_has_torch_function((PyObject*)self)) {
-    return handle_torch_function_setter(
-        self, "_post_accumulate_grad_hooks", obj);
-  }
-  THPUtils_assertRet(
-      -1, obj, "Deletion of _post_accumulate_grad_hooks not allowed!");
-  if (obj == Py_None) {
-    obj = nullptr;
-  }
-  Py_XINCREF(obj);
-  Py_CLEAR(self->post_accumulate_grad_hooks);
-  self->post_accumulate_grad_hooks = obj;
-  const auto& tensor = THPVariable_Unpack(self);
-  if (obj) {
-    torch::autograd::impl::set_post_acc_grad_hooks(
-        tensor, std::make_unique<PyFunctionTensorPostAccGradHooks>(obj));
-  }
-  return 0;
-  END_HANDLE_TH_ERRORS_RET(-1)
-}
-
 PyObject* THPVariable_get_base(THPVariable* self, void* unused) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function((PyObject*)self)) {
@@ -1521,11 +1479,6 @@ static struct PyGetSetDef THPVariable_properties[] = {
      (setter)THPVariable_set_backwards_hooks,
      nullptr,
      nullptr},
-    {"_post_accumulate_grad_hooks",
-     (getter)THPVariable_get_post_accumulate_grad_hooks,
-     (setter)THPVariable_set_post_accumulate_grad_hooks,
-     nullptr,
-     nullptr},
     {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
     {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
     {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
@@ -2083,7 +2036,6 @@ static int THPVariable_subclass_traverse(

   // Finally traverse THPVariable special stuff
   Py_VISIT(var->backward_hooks);
-  Py_VISIT(var->post_accumulate_grad_hooks);
   if (!var->cdata.unsafeIsBorrowed()) {
     const auto& tensor = THPVariable_Unpack(var);
     if (tensor.defined()) {
diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h
index 7ac6d4482bbf..4270052535bd 100644
--- a/torch/csrc/autograd/python_variable.h
+++ b/torch/csrc/autograd/python_variable.h
@@ -22,10 +22,6 @@ struct THPVariable {
   // Hooks to be run on backwards pass (corresponds to Python attr
   // '_backwards_hooks', set by 'register_hook')
   PyObject* backward_hooks = nullptr;
-  // Hooks to be run in the backwards pass after accumulate grad,
-  // i.e., after the .grad has been set (corresponds to Python attr
-  // '_post_accumulate_grad_hooks', set by 'register_post_accumulate_grad_hook')
-  PyObject* post_accumulate_grad_hooks = nullptr;
 };

 TORCH_PYTHON_API void registerPythonTensorClass(
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 44b4f2a92b78..254c443838ed 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -363,19 +363,6 @@ void clear_hooks(const at::TensorBase& self) {
   materialize_autograd_meta(self)->hooks_.clear();
 }

-void set_post_acc_grad_hooks(
-    const at::TensorBase& self,
-    std::unique_ptr<PostAccumulateGradHook> dict) {
-  AutogradMeta* meta = materialize_autograd_meta(self);
-  meta->post_acc_grad_hooks_ = std::move(dict);
-}
-
-std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
-    const Variable& self) {
-  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
-  return get_autograd_meta(self)->post_acc_grad_hooks_;
-}
-
 void set_name(const Variable& self, const std::string& name) {
   materialize_autograd_meta(self)->name_ = name;
 }
diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h
index 328c15ec6040..f90f29c90786 100644
--- a/torch/csrc/autograd/variable.h
+++ b/torch/csrc/autograd/variable.h
@@ -186,12 +186,6 @@ TORCH_API void add_hook(
 TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
 TORCH_API void clear_hooks(const at::TensorBase&);

-TORCH_API void set_post_acc_grad_hooks(
-    const at::TensorBase&,
-    std::unique_ptr<PostAccumulateGradHook> dict);
-TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
-    const Variable&);
-
 TORCH_API void create_cpp_hook(
     const at::TensorBase&,
     bool is_retains_grad_hooks = false);
@@ -237,12 +231,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
   std::shared_ptr<hooks_list> cpp_hooks_list_;

-  // The post_acc_grad_hooks_ field stores only Python hooks
-  // (PyFunctionTensorPostAccGradHooks) that are called after the
-  // .grad field has been accumulated into. This is less complicated
-  // than the hooks_ field, which encapsulates a lot more.
-  std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;
-
   // Only meaningful on leaf variables (must be false otherwise)
   bool requires_grad_{false};

diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h
index d9da62f22bae..e9d5117711c5 100644
--- a/torch/csrc/dynamo/compiled_autograd.h
+++ b/torch/csrc/dynamo/compiled_autograd.h
@@ -347,9 +347,6 @@ class CompiledNodeArgs {
     TORCH_CHECK(
         fn->retains_grad_hooks().empty(),
         "retains_grad_hooks not implemented for compiled autograd");
-    TORCH_CHECK(
-        fn->tensor_post_acc_grad_hooks() == nullptr,
-        "tensor_post_acc_grad_hooks not implemented for compiled autograd");
     for (auto& i : fn->tensor_pre_hooks()) {
       i->compiled_args(*this);
     }
diff --git a/torch/overrides.py b/torch/overrides.py
index 782a865327d6..f10d09db60b2 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1205,7 +1205,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.mT.__get__: lambda self: -1,
         Tensor.mH.__get__: lambda self: -1,
         Tensor._backward_hooks.__get__: lambda self: -1,
-        Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
         Tensor._base.__get__: lambda self: -1,
         Tensor._cdata.__get__: lambda self: -1,
         Tensor.grad.__get__: lambda self: -1,
@@ -1329,7 +1328,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.record_stream: lambda self, stream: -1,
         Tensor.refine_names: lambda self, names: -1,
         Tensor.register_hook: lambda self, hook: -1,
-        Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
         Tensor.rename: lambda self, name: -1,
         Tensor.repeat: lambda self, *size: -1,
         Tensor.requires_grad_: lambda self, requires_grad=True: -1,
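
For reference, below is a minimal sketch of the optimizer-in-backward pattern exercised by the deleted test_post_accumulate_grad_hook_e2e above. It is illustrative only and not part of the patch; it assumes a PyTorch build in which Tensor.register_post_accumulate_grad_hook still exists, i.e. one without this revert applied.

    import torch

    # Run each parameter's optimizer step as soon as its .grad is accumulated.
    model = torch.nn.Linear(3, 2)
    optims = {p: torch.optim.Adam([p]) for p in model.parameters()}

    def optim_step_hook(param):
        # Called after param.grad has been fully accumulated for this backward pass.
        optims[param].step()
        optims[param].zero_grad()

    handles = [
        p.register_post_accumulate_grad_hook(optim_step_hook)
        for p in model.parameters()
    ]

    loss = model(torch.rand(2, 3)).sum()
    loss.backward()  # parameters are updated during the backward pass itself

    for h in handles:
        h.remove()  # after removal, backward() no longer triggers optimizer steps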