Doc note update for complex autograd (#45270)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45270 (screenshots of the rendered documentation are attached in the PR). Test Plan: Imported from OSS. Reviewed By: navahgar. Differential Revision: D24203257. Pulled By: anjali411. fbshipit-source-id: cd637dade5fb40cecf5d9f4bd03d508d36e26fcd
Committed by Facebook GitHub Bot · parent e3112e3ed6 · commit 89256611b5
@ -214,80 +214,278 @@ proper thread locking code to ensure the hooks are thread safe.

.. _complex_autograd-doc:

Autograd for Complex Numbers
----------------------------

**What notion of complex derivative does PyTorch use?**
********************************************************

The short version:

- When you use PyTorch to differentiate any function :math:`f(z)` with complex domain and/or codomain,
  the gradients are computed under the assumption that the function is a part of a larger real-valued
  loss function :math:`g(input)=L`. The gradient computed is :math:`\frac{\partial L}{\partial z^*}`
  (note the conjugation of :math:`z`), the negative of which is precisely the direction of the step
  you should take in gradient descent. Thus, all the existing optimizers work out of
  the box with complex parameters.
- This convention matches TensorFlow's convention for complex
  differentiation, but is different from JAX (which computes
  :math:`\frac{\partial L}{\partial z}`).
- If you have a real-to-real function which internally uses complex
  operations, the convention here doesn't matter: you will always get
  the same result that you would have gotten if it had been implemented
  with only real operations.

If you are curious about the mathematical details, or want to know how
to define complex derivatives in PyTorch, read on.

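
As a quick illustration of this convention (a minimal sketch; the target value, learning rate, and
iteration count are arbitrary choices), a real-valued loss built from a complex parameter can be
minimized with a stock optimizer:

.. code::

    import torch

    # A complex parameter and an arbitrary complex target.
    z = torch.zeros(1, dtype=torch.cfloat, requires_grad=True)
    target = torch.tensor([1.0 - 2.0j])
    opt = torch.optim.SGD([z], lr=0.1)

    for _ in range(200):
        opt.zero_grad()
        # Real-valued loss L = |z - target|^2; .real makes the scalar explicitly real.
        loss = ((z - target) * (z - target).conj()).real.sum()
        loss.backward()  # fills z.grad with the gradient convention described above
        opt.step()

    print(z.detach())  # approaches the target 1 - 2j
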

What are complex derivatives?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Suppose we have a function :math:`f: ℂ → ℂ` which we can decompose into functions :math:`u` and :math:`v`
which compute the real and imaginary parts of the function:

.. code::

    def f(z):
        x, y = real(z), imag(z)
        return u(x, y) + v(x, y) * 1j

where :math:`1j` is a unit imaginary number.

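
For instance (an illustrative decomposition, not part of the original text), :math:`f(z) = z^2`
decomposes into :math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`, which is easy to check
numerically:

.. code::

    import torch

    def f(z):
        x, y = z.real, z.imag
        u = x * x - y * y  # real part of z**2
        v = 2 * x * y      # imaginary part of z**2
        return u + v * 1j

    z = torch.tensor(3.0 + 4.0j)
    print(f(z), z * z)  # both print tensor(-7.+24.j)
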
The mathematical definition of complex-differentiability takes the
limit definition of a derivative and generalizes it to operate on
complex numbers. For a function :math:`f: ℂ → ℂ`, we can write:

.. math::
    f'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

In order for this limit to exist, not only must :math:`u` and :math:`v` be
real differentiable (as above), but :math:`f` must also satisfy the Cauchy-Riemann `equations
<https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations>`_. In
other words: the limit computed for real and imaginary steps (:math:`h`)
must be equal. This is a more restrictive condition.

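
As a brief check of how restrictive this condition is (an illustrative example, not from the
original text), take :math:`f(z) = z^*`, so that :math:`u(x, y) = x` and :math:`v(x, y) = -y`.
The Cauchy-Riemann equations require

.. math::
    \frac{\partial u}{\partial x} = \frac{\partial v}{\partial y}, \qquad
    \frac{\partial u}{\partial y} = -\frac{\partial v}{\partial x}

but here :math:`\frac{\partial u}{\partial x} = 1` while :math:`\frac{\partial v}{\partial y} = -1`,
so conjugation is nowhere complex differentiable even though it is perfectly smooth as a map
:math:`ℝ^2 → ℝ^2`.
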
The complex differentiable functions are commonly known as holomorphic
functions. They are well behaved, have all the nice properties that
you've seen from real differentiable functions, but are practically of no
use in the optimization world. For optimization problems, only real valued objective
functions are used in the research community since complex numbers are not part of any
ordered field and so having complex valued loss does not make much sense.

It also turns out that no interesting real-valued objective fulfills the
Cauchy-Riemann equations. So the theory of holomorphic functions cannot be
used for optimization, and most people therefore use the Wirtinger calculus.

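
(To see why no real-valued objective can satisfy the Cauchy-Riemann equations: if :math:`f` is
real-valued then :math:`v \equiv 0`, and the equations force
:math:`\frac{\partial u}{\partial x} = \frac{\partial v}{\partial y} = 0` and
:math:`\frac{\partial u}{\partial y} = -\frac{\partial v}{\partial x} = 0`, so a real-valued
holomorphic objective is necessarily constant.)
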
Wirtinger Calculus comes into the picture ...
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

So, we have this great theory of complex differentiability and
holomorphic functions, and we can’t use any of it at all, because many
of the commonly used functions are not holomorphic. What’s a poor
mathematician to do? Well, Wirtinger observed that even if :math:`f(z)`
isn’t holomorphic, one could rewrite it as a two variable function
:math:`f(z, z^*)` which is always holomorphic. This is because the real and
imaginary components of :math:`z` can be expressed in terms of
:math:`z` and :math:`z^*` as:

.. math::
    \begin{aligned}
        Re(z) &= \frac {z + z^*}{2} \\
        Im(z) &= \frac {z - z^*}{2j}
    \end{aligned}

Wirtinger calculus suggests to study :math:`f(z, z^*)` instead, which is
guaranteed to be holomorphic if :math:`f` was real differentiable (another
way to think of it is as a change of coordinate system, from :math:`f(x, y)`
to :math:`f(z, z^*)`.) This function has partial derivatives
:math:`\frac{\partial }{\partial z}` and :math:`\frac{\partial}{\partial z^{*}}`.
We can use the chain rule to establish a
relationship between these partial derivatives and the partial
derivatives w.r.t. the real and imaginary components of :math:`z`:

.. math::
    \begin{aligned}
        \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\
                                     &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\
        \\
        \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\
                                     &= 1j * (\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*})
    \end{aligned}

From the above equations, we get:

.. math::
    \begin{aligned}
        \frac{\partial }{\partial z} &= 1/2 * (\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}) \\
        \frac{\partial }{\partial z^*} &= 1/2 * (\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y})
    \end{aligned}

which is the classic definition of Wirtinger calculus that you would find on `Wikipedia <https://en.wikipedia.org/wiki/Wirtinger_derivatives>`_.

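
As a quick worked example (illustrative, not from the original text): for :math:`f(z) = |z|^2 = z z^*`
we have :math:`u(x, y) = x^2 + y^2` and :math:`v(x, y) = 0`, so

.. math::
    \begin{aligned}
        \frac{\partial f}{\partial z} &= 1/2 * (2x - 1j * 2y) = z^* \\
        \frac{\partial f}{\partial z^*} &= 1/2 * (2x + 1j * 2y) = z
    \end{aligned}

which is exactly what treating :math:`z` and :math:`z^*` as independent variables in
:math:`f(z, z^*) = z z^*` also gives.
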
There are a lot of beautiful consequences of this change.

- For one, the Cauchy-Riemann equations translate into simply saying that :math:`\frac{\partial f}{\partial z^*} = 0` (that is to say, the function :math:`f` can be written
  entirely in terms of :math:`z`, without making reference to :math:`z^*`).
- Another important (and somewhat counterintuitive) result, as we’ll see later, is that when we do optimization on a real-valued loss, the step we should
  take while making a variable update is given by :math:`\frac{\partial Loss}{\partial z^*}` (not :math:`\frac{\partial Loss}{\partial z}`).

For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf

How is Wirtinger Calculus useful in optimization?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Researchers in audio and other fields, more commonly, use gradient
descent to optimize real valued loss functions with complex variables.
Typically, these people treat the real and imaginary values as separate
channels that can be updated. For a step size :math:`s/2` and loss
:math:`L`, we can write the following equations in :math:`ℝ^2`:

.. math::
    \begin{aligned}
        x_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} \\
        y_{n+1} &= y_n - (s/2) * \frac{\partial L}{\partial y}
    \end{aligned}

How do these equations translate into complex space :math:`ℂ`?

.. math::
    \begin{aligned}
        z_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (s/2) * \frac{\partial L}{\partial y}) \\
                &= z_n - s * 1/2 * (\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}) \\
                &= z_n - s * \frac{\partial L}{\partial z^*}
    \end{aligned}

Something very interesting has happened: Wirtinger calculus tells us
that we can simplify the complex variable update formula above to only
refer to the conjugate Wirtinger derivative
:math:`\frac{\partial L}{\partial z^*}`, giving us exactly the step we take in optimization.

Because the conjugate Wirtinger derivative gives us exactly the correct step for a real valued loss function, PyTorch gives you this derivative
when you differentiate a function with a real valued loss.

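
A tiny numeric check of this derivation (a minimal sketch; the loss, the point, and the step size
are arbitrary choices), using plain Python complex arithmetic:

.. code::

    # L(z) = |z - t|^2 with t = 1 - 2j, evaluated at z = 3 + 4j, with step size s = 0.1
    t, z, s = 1 - 2j, 3 + 4j, 0.1
    x, y = z.real, z.imag

    dLdx = 2 * (x - t.real)              # partial L / partial x
    dLdy = 2 * (y - t.imag)              # partial L / partial y
    dLdzstar = 0.5 * (dLdx + 1j * dLdy)  # conjugate Wirtinger derivative

    # Updating the two real channels with step s/2 ...
    real_update = (x - (s / 2) * dLdx) + 1j * (y - (s / 2) * dLdy)
    # ... matches a single complex update with step s taken along dL/dz*.
    complex_update = z - s * dLdzstar

    assert abs(real_update - complex_update) < 1e-12
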
How does PyTorch compute the conjugate Wirtinger derivative?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Typically, our derivative formulas take in `grad_output` as an input,
representing the incoming Vector-Jacobian product that we’ve already
computed, aka, :math:`\frac{\partial L}{\partial s^*}`, where :math:`L`
is the loss of the entire computation (producing a real loss) and
:math:`s` is the output of our function. The goal here is to compute
:math:`\frac{\partial L}{\partial z^*}`, where :math:`z` is the input of
the function. It turns out that in the case of real loss, we can
get away with *only* calculating :math:`\frac{\partial L}{\partial z^*}`,
even though the chain rule implies that we also need to
have access to :math:`\frac{\partial L}{\partial z}`. If you want
to skip this derivation, look at the last equation in this section
and then skip to the next section.

Let’s continue working with :math:`f: ℂ → ℂ` defined as
:math:`f(z) = f(x+yj) = u(x, y) + v(x, y)j`. As discussed above,
autograd’s gradient convention is centered around optimization for real
valued loss functions, so let’s assume :math:`f` is a part of a larger
real valued loss function :math:`g`. Using the chain rule, we can write:

.. math::
    :label: [1]

    \frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

Now using the Wirtinger derivative definition, we can write:

.. math::
    \begin{aligned}
        \frac{\partial L}{\partial s} &= 1/2 * (\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j) \\
        \frac{\partial L}{\partial s^*} &= 1/2 * (\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j)
    \end{aligned}

It should be noted here that since :math:`u` and :math:`v` are real
functions, and :math:`L` is real by our assumption that :math:`f` is a
part of a real valued function, we have:

.. math::
    :label: [2]

    (\frac{\partial L}{\partial s})^* = \frac{\partial L}{\partial s^*}

i.e., :math:`\frac{\partial L}{\partial s}` equals :math:`grad\_output^*`.

Solving the above equations for :math:`\frac{\partial L}{\partial u}` and :math:`\frac{\partial L}{\partial v}`, we get:

.. math::
    :label: [3]

    \begin{aligned}
        \frac{\partial L}{\partial u} &= \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\
        \frac{\partial L}{\partial v} &= 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*})
    \end{aligned}

Substituting :eq:`[3]` in :eq:`[1]`, we get:

.. math::
    \begin{aligned}
        \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) * \frac{\partial u}{\partial z^*} + 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) * \frac{\partial v}{\partial z^*} \\
        &= \frac{\partial L}{\partial s} * (\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j) + \frac{\partial L}{\partial s^*} * (\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j) \\
        &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\
        &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*}
    \end{aligned}

Using :eq:`[2]`, we get:

.. math::
    :label: [4]

    \begin{aligned}
        \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s^*})^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * (\frac{\partial s}{\partial z})^* \\
        &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * {(\frac{\partial s}{\partial z})}^* }
    \end{aligned}

This last equation is the important one for writing your own gradients,
as it decomposes our derivative formula into a simpler one that is easy
to compute by hand.

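
In particular, when :math:`s` is holomorphic in :math:`z` (so that :math:`\frac{\partial s}{\partial z^*} = 0`),
the boxed equation collapses to

.. math::
    \frac{\partial L}{\partial z^*} = grad\_output * {(\frac{\partial s}{\partial z})}^*
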
How can I write my own derivative formula for a complex function?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The above boxed equation gives us the general formula for all
derivatives on complex functions. However, we still need to
compute :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}`.
There are two ways you could do this:

- The first way is to just use the definition of Wirtinger derivatives directly and calculate :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}` by
  using :math:`\frac{\partial s}{\partial x}` and :math:`\frac{\partial s}{\partial y}`
  (which you can compute in the normal way).
- The second way is to use the change of variables trick and rewrite :math:`f(z)` as a two variable function :math:`f(z, z^*)`, and compute
  the conjugate Wirtinger derivatives by treating :math:`z` and :math:`z^*` as independent variables. This is often easier; for example, if the function in question is holomorphic, only :math:`z` will be used (and :math:`\frac{\partial s}{\partial z^*}` will be zero).

Let's consider the function :math:`f(z = x + yj) = c * z = c * (x+yj)` as an example, where :math:`c \in ℝ`.

Using the first way to compute the Wirtinger derivatives, we have:

.. math::
    \begin{aligned}
        \frac{\partial s}{\partial z} &= 1/2 * (\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j) \\
                                      &= 1/2 * (c - (c * 1j) * 1j) \\
                                      &= c \\
        \\
        \frac{\partial s}{\partial z^*} &= 1/2 * (\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j) \\
                                        &= 1/2 * (c + (c * 1j) * 1j) \\
                                        &= 0
    \end{aligned}

Using :eq:`[4]`, and `grad\_output = 1.0` (which is the default grad output value used when :func:`backward` is called on a scalar output in PyTorch), we get:

.. math::
    \frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

Using the second way to compute Wirtinger derivatives, we directly get:

.. math::
    \begin{aligned}
        \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\
                                      &= c \\
        \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\
                                        &= 0
    \end{aligned}

And using :eq:`[4]` again, we get :math:`\frac{\partial L}{\partial z^*} = c`. As you can see, the second way involves fewer calculations and comes
in handier for quick derivations.

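
To make this concrete in code, below is a minimal sketch of the same :math:`c * z` example written
as a custom :class:`torch.autograd.Function` whose backward applies the boxed equation :eq:`[4]`
directly (the class name and the test values are illustrative choices):

.. code::

    import torch

    class ScaleComplex(torch.autograd.Function):
        """s = c * z for a fixed real c, with backward written via the boxed equation."""

        @staticmethod
        def forward(ctx, z, c):
            ctx.c = c
            return c * z

        @staticmethod
        def backward(ctx, grad_output):
            # Boxed equation [4]: grad_output.conj() * ds/dz* + grad_output * conj(ds/dz).
            # For s = c * z with real c: ds/dz* = 0 and ds/dz = c, so this is grad_output * c.
            grad_z = grad_output.conj() * 0 + grad_output * ctx.c
            return grad_z, None  # no gradient for the plain-Python constant c

    z = torch.tensor([1.0 - 1.0j], requires_grad=True)
    w = torch.tensor([1.0 - 1.0j], requires_grad=True)

    ScaleComplex.apply(z, 3.0).real.sum().backward()  # custom rule from the boxed equation
    (3.0 * w).real.sum().backward()                   # same computation with built-in ops

    print(z.grad, w.grad)  # the two gradients agree
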
What about cross-domain functions?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some functions map from complex inputs to real outputs, or vice versa.
These functions form a special case of :eq:`[4]`, which we can derive using the
chain rule:

- For :math:`f: ℂ → ℝ`, we get (see the worked instance after this list):

  .. math::
      \frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}

- For :math:`f: ℝ → ℂ`, we get:

  .. math::
      \frac{\partial L}{\partial z^*} = 2 * Re(grad\_output^* * \frac{\partial s}{\partial z^{*}})
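
As a quick worked instance of the first formula (illustrative, not from the original text): for
:math:`s = Re(z)` we have :math:`u(x, y) = x` and :math:`v(x, y) = 0`, so
:math:`\frac{\partial s}{\partial z^{*}} = 1/2`, and the formula gives
:math:`\frac{\partial L}{\partial z^*} = 2 * grad\_output * 1/2 = grad\_output`.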