Doc note update for complex autograd (#45270)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45270

<img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png">
<img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png">
<img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png">
<img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png">
<img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png">
<img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png">

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D24203257

Pulled By: anjali411

fbshipit-source-id: cd637dade5fb40cecf5d9f4bd03d508d36e26fcd
This commit is contained in:
anjali411
2020-10-08 14:59:54 -07:00
committed by Facebook GitHub Bot
parent e3112e3ed6
commit 89256611b5


@@ -214,80 +214,278 @@ proper thread locking code to ensure the hooks are thread safe.
.. _complex_autograd-doc:

Autograd for Complex Numbers
----------------------------

The short version:

- When you use PyTorch to differentiate any function :math:`f(z)` with complex domain and/or codomain,
  the gradients are computed under the assumption that the function is a part of a larger real-valued
  loss function :math:`g(input)=L`. The gradient computed is :math:`\frac{\partial L}{\partial z^*}`
  (note the conjugation of :math:`z`), which is precisely the direction of the step
  you should take in gradient descent. Thus, all the existing optimizers work out of
  the box with complex parameters.
- This convention matches TensorFlow's convention for complex
  differentiation, but is different from JAX (which computes
  :math:`\frac{\partial L}{\partial z}`).
- If you have a real-to-real function which internally uses complex
  operations, the convention here doesn't matter: you will always get
  the same result that you would have gotten if it had been implemented
  with only real operations.
If you are curious about the mathematical details, or want to know how
to define complex derivatives in PyTorch, read on.

What are complex derivatives?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The mathematical definition of complex-differentiability takes the
limit definition of a derivative and generalizes it to operate on
complex numbers. For a function :math:`f(z = x + yj) = u(x, y) + v(x, y)j`
(where :math:`u` and :math:`v` are real-valued functions of two real
variables, and :math:`j` is the imaginary unit), we can write:

.. math::
    f'(z) = \lim_{h \to 0, h \in \mathbb{C}} \frac{f(z+h) - f(z)}{h}

In order for this limit to exist, not only must :math:`u` and :math:`v` be
real differentiable, but :math:`f` must also satisfy the Cauchy-Riemann `equations
<https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations>`_. In
other words: the limit computed for real and imaginary steps (:math:`h`)
must be equal. This is a more restrictive condition.
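To make the restriction concrete, here is a small numeric sketch (plain Python, not part of PyTorch; the helper name is made up) comparing the difference quotient for a real step against an imaginary step. The two agree for the holomorphic :math:`f(z) = z^2` but disagree for :math:`f(z) = z^*`:

```python
# Sketch: compare difference quotients (f(z+h) - f(z)) / h for a real
# step h and an imaginary step h. They agree for holomorphic functions
# (Cauchy-Riemann) and disagree otherwise, so f'(z) does not exist in
# the non-holomorphic case.
def diff_quot(f, z, h):
    return (f(z + h) - f(z)) / h

z = 1.0 + 2.0j
eps = 1e-6

square = lambda w: w * w        # holomorphic
conj = lambda w: w.conjugate()  # not holomorphic

# Holomorphic: real-step and imaginary-step quotients agree (both ~ 2z).
d_re = diff_quot(square, z, eps)
d_im = diff_quot(square, z, eps * 1j)
agree = abs(d_re - d_im) < 1e-4

# Non-holomorphic: the two quotients differ (1 vs -1).
c_re = diff_quot(conj, z, eps)
c_im = diff_quot(conj, z, eps * 1j)
disagree = abs(c_re - c_im) > 1.0
```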
The complex differentiable functions are commonly known as holomorphic
functions. They are well behaved, have all the nice properties that
you've seen from real differentiable functions, but are practically of no
use in the optimization world. For optimization problems, only real-valued objective
functions are used in the research community, since complex numbers are not part of any
ordered field and so having a complex-valued loss does not make much sense.

It also turns out that no interesting real-valued objective fulfills the
Cauchy-Riemann equations. So the theory of holomorphic functions cannot be
used for optimization, and most people therefore use Wirtinger calculus.
Wirtinger Calculus comes into the picture ...
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

So, we have this great theory of complex differentiability and
holomorphic functions, and we can't use any of it at all, because many
of the commonly used functions are not holomorphic. What's a poor
mathematician to do? Well, Wirtinger observed that even if :math:`f(z)`
isn't holomorphic, one could rewrite it as a two variable function
:math:`f(z, z^*)` which is always holomorphic. This is because the real
and imaginary components of :math:`z` can be expressed in terms of
:math:`z` and :math:`z^*` as:

.. math::
    \begin{aligned}
        Re(z) &= \frac {z + z^*}{2} \\
        Im(z) &= \frac {z - z^*}{2j}
    \end{aligned}
Wirtinger calculus suggests studying :math:`f(z, z^*)` instead, which is
guaranteed to be holomorphic if :math:`f` was real differentiable (another
way to think of it is as a change of coordinate system, from :math:`f(x, y)`
to :math:`f(z, z^*)`). This function has partial derivatives
:math:`\frac{\partial }{\partial z}` and :math:`\frac{\partial}{\partial z^{*}}`.
We can use the chain rule to establish a
relationship between these partial derivatives and the partial
derivatives w.r.t. the real and imaginary components of :math:`z`:

.. math::
    \begin{aligned}
        \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\
                                     &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\
        \\
        \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\
                                     &= 1j * (\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*})
    \end{aligned}

From the above equations, we get:

.. math::
    \begin{aligned}
        \frac{\partial }{\partial z} &= 1/2 * (\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}) \\
        \frac{\partial }{\partial z^*} &= 1/2 * (\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y})
    \end{aligned}
which is the classic definition of Wirtinger calculus that you would find on `Wikipedia <https://en.wikipedia.org/wiki/Wirtinger_derivatives>`_.

There are a lot of beautiful consequences of this change.

- For one, the Cauchy-Riemann equations translate into simply saying that :math:`\frac{\partial f}{\partial z^*} = 0` (that is to say, the function :math:`f` can be written
  entirely in terms of :math:`z`, without making reference to :math:`z^*`).
- Another important (and somewhat counterintuitive) result, as we'll see later, is that when we do optimization on a real-valued loss, the step we should
  take while making variable update is given by :math:`\frac{\partial Loss}{\partial z^*}` (not :math:`\frac{\partial Loss}{\partial z}`).

For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf
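The two Wirtinger derivative formulas above can be checked numerically. The sketch below (plain Python; the helper name is made up for illustration) evaluates :math:`\partial/\partial x` and :math:`\partial/\partial y` by central differences and combines them per the formulas, confirming that :math:`\partial f/\partial z^* = 0` for the holomorphic :math:`f(z) = z^2` but not for :math:`f(z) = z z^*`:

```python
# Sketch: Wirtinger derivatives from the formulas
#   d/dz  = 1/2 (d/dx - 1j d/dy),   d/dz* = 1/2 (d/dx + 1j d/dy),
# with d/dx and d/dy approximated by central differences.
def wirtinger(f, z, eps=1e-6):
    dfdx = (f(z + eps) - f(z - eps)) / (2 * eps)
    dfdy = (f(z + eps * 1j) - f(z - eps * 1j)) / (2 * eps)
    df_dz = 0.5 * (dfdx - 1j * dfdy)
    df_dzstar = 0.5 * (dfdx + 1j * dfdy)
    return df_dz, df_dzstar

z = 1.0 + 2.0j

# Holomorphic z**2: d/dz = 2z and d/dz* = 0 (Cauchy-Riemann in Wirtinger form).
dz, dzstar = wirtinger(lambda w: w * w, z)

# Non-holomorphic |z|^2 = z * z*: d/dz = z* and d/dz* = z.
nz, nzstar = wirtinger(lambda w: w * w.conjugate(), z)
```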
How is Wirtinger Calculus useful in optimization?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Researchers in audio and other fields, more commonly, use gradient
descent to optimize real-valued loss functions with complex variables.
Typically, these people treat the real and imaginary values as separate
channels that can be updated. For a step size :math:`s/2` and loss
:math:`L`, we can write the following equations in :math:`\mathbb{R}^2`:

.. math::
    \begin{aligned}
        x_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} \\
        y_{n+1} &= y_n - (s/2) * \frac{\partial L}{\partial y}
    \end{aligned}

How do these equations translate into complex space :math:`\mathbb{C}`?

.. math::
    \begin{aligned}
        z_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (s/2) * \frac{\partial L}{\partial y}) \\
                &= z_n - s * 1/2 * (\frac{\partial L}{\partial x} + 1j * \frac{\partial L}{\partial y}) \\
                &= z_n - s * \frac{\partial L}{\partial z^*}
    \end{aligned}
Something very interesting has happened: Wirtinger calculus tells us
that we can simplify the complex variable update formula above to only
refer to the conjugate Wirtinger derivative
:math:`\frac{\partial L}{\partial z^*}`, giving us exactly the step we take in optimization.

Because the conjugate Wirtinger derivative gives us exactly the correct step for a real-valued loss function, PyTorch gives you this derivative
when you differentiate a function with a real-valued loss.
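The equivalence derived above can be sanity-checked numerically. In this sketch (plain Python; the derivatives are hand-computed for the specific loss :math:`L(z) = |z|^2`, where :math:`\frac{\partial L}{\partial z^*} = z`), one update done as two real channels with step :math:`s/2` coincides with a single complex step of :math:`-s \cdot \frac{\partial L}{\partial z^*}`:

```python
# Sketch: one optimization step on L(z) = |z|^2 done two ways.
s = 0.2
x, y = 1.0, 2.0
z = complex(x, y)

# Two real channels, step s/2 each (dL/dx = 2x, dL/dy = 2y).
x_new = x - (s / 2) * (2 * x)
y_new = y - (s / 2) * (2 * y)
two_channel = complex(x_new, y_new)

# Single complex step with the conjugate Wirtinger derivative dL/dz* = z.
complex_step = z - s * z
```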
How does PyTorch compute the conjugate Wirtinger derivative?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Typically, our derivative formulas take in `grad_output` as an input,
representing the incoming Vector-Jacobian product that we've already
computed, aka, :math:`\frac{\partial L}{\partial s^*}`, where :math:`L`
is the loss of the entire computation (producing a real loss) and
:math:`s` is the output of our function. The goal here is to compute
:math:`\frac{\partial L}{\partial z^*}`, where :math:`z` is the input of
the function. It turns out that in the case of real loss, we can
get away with *only* calculating :math:`\frac{\partial L}{\partial z^*}`,
even though the chain rule implies that we also need to
have access to :math:`\frac{\partial L}{\partial z}`. If you want
to skip this derivation, look at the last equation in this section
and then skip to the next section.
Let's continue working with :math:`f: \mathbb{C} \to \mathbb{C}` defined as
:math:`f(z) = f(x+yj) = u(x, y) + v(x, y)j`. As discussed above,
autograd's gradient convention is centered around optimization for real-valued
loss functions, so let's assume :math:`f` is a part of a larger
real-valued loss function :math:`g`. Using the chain rule, we can write:

.. math::
    \frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}
    :label: [1]

Now using the Wirtinger derivative definition, we can write:

.. math::
    \begin{aligned}
        \frac{\partial L}{\partial s} &= 1/2 * (\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j) \\
        \frac{\partial L}{\partial s^*} &= 1/2 * (\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j)
    \end{aligned}

It should be noted here that since :math:`u` and :math:`v` are real
functions, and :math:`L` is real by our assumption that :math:`f` is a
part of a real-valued function, we have:

.. math::
    (\frac{\partial L}{\partial s})^* = \frac{\partial L}{\partial s^*}
    :label: [2]

i.e., :math:`\frac{\partial L}{\partial s}` equals :math:`grad\_output^*`.

Solving the above equations for :math:`\frac{\partial L}{\partial u}` and :math:`\frac{\partial L}{\partial v}`, we get:

.. math::
    \begin{aligned}
        \frac{\partial L}{\partial u} &= \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\
        \frac{\partial L}{\partial v} &= 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*})
    \end{aligned}
    :label: [3]

Substituting :eq:`[3]` in :eq:`[1]`, we get:

.. math::
    \begin{aligned}
        \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) * \frac{\partial u}{\partial z^*} + 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) * \frac{\partial v}{\partial z^*} \\
                                        &= \frac{\partial L}{\partial s} * (\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j) + \frac{\partial L}{\partial s^*} * (\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j) \\
                                        &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\
                                        &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\
    \end{aligned}

Using :eq:`[2]`, we get:

.. math::
    \begin{aligned}
        \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s^*})^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * (\frac{\partial s}{\partial z})^* \\
                                        &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * {(\frac{\partial s}{\partial z})}^* } \\
    \end{aligned}
    :label: [4]

This last equation is the important one for writing your own gradients,
as it decomposes our derivative formula into a simpler one that is easy
to compute by hand.
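The boxed equation can be exercised directly. The sketch below (plain Python; `backward_rule` is an illustrative helper, not a PyTorch API) implements it and checks it on the holomorphic :math:`s = f(z) = z^2` wrapped in the real loss :math:`L = |s|^2`, for which :math:`grad\_output = \frac{\partial L}{\partial s^*} = s`:

```python
# Sketch (illustrative helper, not a PyTorch API): the boxed equation [4]
# as a reusable backward rule.
def backward_rule(grad_output, ds_dz, ds_dzstar):
    # dL/dz* = grad_output^* * ds/dz* + grad_output * (ds/dz)^*
    return (grad_output.conjugate() * ds_dzstar
            + grad_output * ds_dz.conjugate())

# Example: s = f(z) = z^2 (holomorphic: ds/dz = 2z, ds/dz* = 0),
# inside the real loss L = |s|^2 = s * s*, so grad_output = dL/ds* = s.
z = 1.0 + 2.0j
s_out = z * z
grad_output = s_out

dL_dzstar = backward_rule(grad_output, ds_dz=2 * z, ds_dzstar=0j)

# Analytically, L = |z|^4 = (z z*)^2 gives dL/dz* = 2 * z^2 * z*.
expected = 2 * z * z * z.conjugate()
```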
How can I write my own derivative formula for a complex function?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The above boxed equation gives us the general formula for all
derivatives on complex functions. However, we still need to
compute :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}`.
There are two ways you could do this:

- The first way is to just use the definition of Wirtinger derivatives directly and calculate :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}` by
  using :math:`\frac{\partial s}{\partial x}` and :math:`\frac{\partial s}{\partial y}`
  (which you can compute in the normal way).
- The second way is to use the change of variables trick and rewrite :math:`f(z)` as a two variable function :math:`f(z, z^*)`, and compute
  the conjugate Wirtinger derivatives by treating :math:`z` and :math:`z^*` as independent variables. This is often easier; for example, if the function in question is holomorphic, only :math:`z` will be used (and :math:`\frac{\partial s}{\partial z^*}` will be zero).

Let's consider the function :math:`f(z = x + yj) = c * z = c * (x+yj)` as an example, where :math:`c \in \mathbb{C}`.

Using the first way to compute the Wirtinger derivatives, we have:
.. math::
    \begin{aligned}
        \frac{\partial s}{\partial z} &= 1/2 * (\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j) \\
                                      &= 1/2 * (c - (c * 1j) * 1j) \\
                                      &= c \\
        \\
        \frac{\partial s}{\partial z^*} &= 1/2 * (\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j) \\
                                        &= 1/2 * (c + (c * 1j) * 1j) \\
                                        &= 0 \\
    \end{aligned}

Using :eq:`[4]`, and `grad\_output = 1.0` (which is the default grad output value used when :func:`backward` is called on a scalar output in PyTorch), we get:

.. math::
    \frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c^* = c^*

Using the second way to compute Wirtinger derivatives, we directly get:

.. math::
    \begin{aligned}
        \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\
                                      &= c \\
        \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\
                                        &= 0
    \end{aligned}

And using :eq:`[4]` again, we get :math:`\frac{\partial L}{\partial z^*} = c^*`. As you can see, the second way involves fewer calculations and comes
in handier for quick derivative computations.
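This example can be cross-checked numerically (plain Python sketch; the choice of loss is an assumption for illustration). Embedding :math:`f(z) = c * z` in the real loss :math:`L(z) = 2\,Re(c * z)` makes :math:`\frac{\partial L}{\partial s^*} = 1` play the role of `grad_output`, and evaluating :math:`\frac{\partial L}{\partial z^*} = 1/2(\frac{\partial L}{\partial x} + 1j \frac{\partial L}{\partial y})` by central differences recovers :math:`c^*`:

```python
# Sketch: numeric cross-check that the gradient of f(z) = c * z with
# grad_output = 1 is c^* (not c). L(z) = 2 * Re(c * z) = s + s*, for
# which dL/ds* = 1, matching grad_output = 1.0.
c = 3.0 - 4.0j
z = 1.0 + 2.0j
eps = 1e-6

def L(w):
    return 2 * (c * w).real

dLdx = (L(z + eps) - L(z - eps)) / (2 * eps)
dLdy = (L(z + eps * 1j) - L(z - eps * 1j)) / (2 * eps)
dL_dzstar = 0.5 * (dLdx + 1j * dLdy)  # Wirtinger formula for dL/dz*
```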
What about cross-domain functions?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some functions map from complex inputs to real outputs, or vice versa.
These functions form a special case of :eq:`[4]`, which we can derive using the
chain rule:

- For :math:`f: \mathbb{C} \to \mathbb{R}`, we get:

    .. math::
        \frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}

- For :math:`f: \mathbb{R} \to \mathbb{C}`, we get:

    .. math::
        \frac{\partial L}{\partial z^*} = 2 * Re(grad\_output^* * \frac{\partial s}{\partial z^{*}})
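A small numeric sketch of the complex-to-real case (plain Python; derivatives hand-computed for this specific function): for :math:`s = f(z) = |z|^2` we have :math:`\frac{\partial s}{\partial z^*} = z`, and with `grad_output = 1.0` the special-case formula reproduces the real-pair gradient :math:`\frac{\partial L}{\partial x} + \frac{\partial L}{\partial y} j = 2x + 2yj`:

```python
# Sketch: the C -> R special case on s = f(z) = |z|^2, where ds/dz* = z
# (computed by hand). With grad_output = 1.0 the formula gives 2 * z,
# which equals dL/dx + (dL/dy) j = 2x + 2y j for L = x^2 + y^2.
z = 1.0 + 2.0j
grad_output = 1.0

ds_dzstar = z                      # by hand, for s = z * z.conjugate()
formula = 2 * grad_output * ds_dzstar

real_pair = complex(2 * z.real, 2 * z.imag)  # (dL/dx) + (dL/dy) j
```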