Add gradient choice detail to autograd doc

Trying to clarify what our backward functions should compute.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76898
Approved by: https://github.com/soulitzer, https://github.com/Lezcano
Alban Desmaison
2022-05-06 21:12:25 +00:00
committed by PyTorch MergeBot
parent c152817926
commit d5210a4269


@@ -87,6 +87,22 @@ subject to change and that users should not rely on.
You can control how PyTorch does packing / unpacking with :ref:`saved-tensors-hooks-doc`.
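For reference, a minimal sketch of such a hook pair, assuming the ``torch.autograd.graph.saved_tensors_hooks`` context manager described in that section::

    import torch

    def pack(tensor):
        # Called when autograd saves a tensor for backward;
        # one could move it to CPU or compress it here.
        return tensor

    def unpack(packed):
        # Called when the saved value is needed during backward.
        return packed

    a = torch.randn(3, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        b = a * a  # `a` is saved for backward through the hooks above
    b.sum().backward()
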
.. _non-differentiable-func-grad:
Gradients for non-differentiable functions
------------------------------------------
The gradient computation using Automatic Differentiation is only valid when each elementary function being used is differentiable.
Unfortunately, many of the functions we use in practice do not have this property (relu or sqrt at 0, for example).
Even when we pick a value for the gradient at such points, we cannot always guarantee that the returned gradient will be correct: for example, :math:`f(x) = x` written as :math:`\text{relu}(x) - \text{relu}(-x)` gets a gradient of 0 at 0 instead of 1 when the gradient of relu at 0 is taken to be 0, and no single choice for the gradient of relu at 0 makes every such rewriting correct.
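For illustration, a minimal sketch of this behaviour::

    import torch

    x = torch.tensor(0.0, requires_grad=True)
    y = torch.relu(x) - torch.relu(-x)  # computes f(x) = x exactly
    y.backward()
    print(x.grad)  # tensor(0.) even though the true derivative of f at 0 is 1
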
To try and reduce the impact of this limitation, we define the gradients of the elementary operations by applying the following rules in order (a few of these rules are illustrated in the sketch after the list):
#. If the function is differentiable and thus a gradient exists at the current point, use it.
#. If the function is convex (at least locally), use the sub-gradient of minimum norm (it is the steepest descent direction; see Exercise 2.7 of "Convex Optimization Algorithms" by Bertsekas, D. P., and "Steepest Descent for Optimization Problems with Nondifferentiable Cost Functionals" by Bertsekas, D. P., and Mitter, S. K., 1971, for details and proofs).
#. If the function is concave (at least locally), use the super-gradient of minimum norm (by a similar argument to the one above).
#. If the function is defined, define the gradient at the current point by continuity (note that :math:`\infty` is possible here, for example for :math:`\sqrt{x}` at :math:`0`). If multiple values are possible, pick one arbitrarily.
#. If the function is not defined (:math:`\sqrt{-1}`, :math:`\log(-1)`, or most functions when the input is :math:`\text{NaN}`, for example), then the value used as the gradient is arbitrary (we might also raise an error, but that is not guaranteed). Most functions will use :math:`\text{NaN}` as the gradient, but for performance reasons, some functions will use other, non-:math:`\text{NaN}` values (:math:`\log(-1)`, for example).
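For concreteness, a small sketch of the convex, defined-by-continuity, and undefined cases (the exact value in the undefined case is not guaranteed, as noted above)::

    import torch

    # Convex case: relu's sub-gradient of minimum norm at 0 is 0.
    x = torch.tensor(0.0, requires_grad=True)
    torch.relu(x).backward()
    print(x.grad)  # tensor(0.)

    # Defined-by-continuity case: the gradient of sqrt at 0 is +inf.
    y = torch.tensor(0.0, requires_grad=True)
    torch.sqrt(y).backward()
    print(y.grad)  # tensor(inf)

    # Undefined case: sqrt(-1) is not defined, so the gradient is arbitrary.
    z = torch.tensor(-1.0, requires_grad=True)
    torch.sqrt(z).backward()
    print(z.grad)  # tensor(nan) here, but this value is not guaranteed
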
.. _locally-disable-grad-doc:
Locally disabling gradient computation