[docs] Add section about tensor hooks on in-place in autograd note (#93116)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93116
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2023-01-31 19:47:12 +00:00
committed by PyTorch MergeBot
parent 76b999803a
commit 77cbaedd5c

@ -889,4 +889,33 @@ registered to Node. As the forward is computed, hooks are registered to grad_fn
to the inputs and outputs of the module. Because a module may take multiple inputs and return
multiple outputs, a dummy custom autograd Function is first applied to the inputs of the module
before forward and the outputs of the module before the output of forward is returned to ensure
that those Tensors share a single grad_fn, which we can then attach our hooks to.
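A hypothetical sketch of that trick (the ``_SharedGradFn`` class and its ``clone``-based pass-through are illustrative assumptions, not the actual implementation): routing several Tensors through one custom Function gives all of its outputs a single shared grad_fn.

.. code::

    import torch

    class _SharedGradFn(torch.autograd.Function):
        # Identity-like Function: clones each input on the way in and
        # passes gradients straight back on the way out.
        @staticmethod
        def forward(ctx, *inputs):
            return tuple(t.clone() for t in inputs)

        @staticmethod
        def backward(ctx, *grads):
            return grads

    a = torch.randn(2, requires_grad=True)
    b = torch.randn(2, requires_grad=True)
    x, y = _SharedGradFn.apply(a, b)
    assert x.grad_fn is y.grad_fn  # both outputs share a single grad_fn node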
Behavior of Tensor hooks when Tensor is modified in-place
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Usually, hooks registered to a Tensor receive the gradient of the outputs with respect to that
Tensor, where the value of the Tensor is taken to be its value at the time backward is computed.
However, if you register a hook to a Tensor and then modify that Tensor in-place, a hook
registered before the in-place modification similarly receives the gradient of the outputs
with respect to the Tensor, but the value of the Tensor is taken to be its value before the
in-place modification.
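A minimal sketch of this pre-modification behavior (the comment about the expected value is an illustration derived from the chain rule, not taken from the note):

.. code::

    import torch

    t = torch.tensor(1., requires_grad=True).sin()
    t.register_hook(print)  # registered before the in-place modification
    t.cos_()
    t.backward()
    # The hook prints the gradient with respect to t's value *before* cos_,
    # i.e. d cos(u)/du evaluated at u = sin(1.), which is -sin(sin(1.))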
If you prefer the former behavior, i.e. gradients with respect to the Tensor's current value,
you should register your hooks to the Tensor after all in-place modifications to it have been made.
For example:

.. code::

    import torch

    def fn(grad):
        print(grad)  # prints tensor(1.): the gradient w.r.t. t's current value

    t = torch.tensor(1., requires_grad=True).sin()
    t.cos_()             # in-place modification happens first
    t.register_hook(fn)  # registered after, so fn sees gradients w.r.t. the cos value
    t.backward()
Furthermore, it can be helpful to know that under the hood,
when hooks are registered to a Tensor, they actually become permanently bound to the grad_fn
of that Tensor. So if that Tensor is then modified in-place,
even though the Tensor now has a new grad_fn, hooks registered before it was
modified in-place will continue to be associated with the old grad_fn, i.e. they will
fire when that Tensor's old grad_fn is reached in the graph by the autograd engine.
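A short sketch of this binding (the variable names are illustrative; the behavior follows from the description above):

.. code::

    import torch

    t = torch.tensor(1., requires_grad=True).sin()
    old_grad_fn = t.grad_fn                        # SinBackward0
    t.register_hook(lambda g: print("fired", g))   # bound to old_grad_fn
    t.cos_()
    print(t.grad_fn is old_grad_fn)                # False: t now has a new grad_fn
    t.backward()                                   # hook still fires at the old node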