Update autograd.rst (#101007)

Fixes #ISSUE_NUMBER

typo fix and small change to improve clarity

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101007
Approved by: https://github.com/lezcano, https://github.com/anjali411
This commit is contained in:
Ran Ding
2023-05-12 11:47:51 +00:00
committed by PyTorch MergeBot
parent aa8dcab1ce
commit b5c8d0359c

View File

@ -455,8 +455,8 @@ imaginary of the components of :math:`z` can be expressed in terms of
.. math::
\begin{aligned}
Re(z) &= \frac {z + z^*}{2} \\
Im(z) &= \frac {z - z^*}{2j}
\mathrm{Re}(z) &= \frac {z + z^*}{2} \\
\mathrm{Im}(z) &= \frac {z - z^*}{2j}
\end{aligned}
Wirtinger calculus suggests to study :math:`f(z, z^*)` instead, which is
@ -474,15 +474,15 @@ derivatives w.r.t., the real and imaginary components of :math:`z`.
&= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\
\\
\frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\
&= 1j * (\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*})
&= 1j * \left(\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}\right)
\end{aligned}
From the above equations, we get:
.. math::
\begin{aligned}
\frac{\partial }{\partial z} &= 1/2 * (\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}) \\
\frac{\partial }{\partial z^*} &= 1/2 * (\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y})
\frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}\right) \\
\frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}\right)
\end{aligned}
which is the classic definition of Wirtinger calculus that you would find on `Wikipedia <https://en.wikipedia.org/wiki/Wirtinger_derivatives>`_.
@ -516,7 +516,7 @@ How do these equations translate into complex space :math:``?
.. math::
\begin{aligned}
z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\
&= z_n - \alpha * 1/2 * (\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}) \\
&= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\
&= z_n - \alpha * \frac{\partial L}{\partial z^*}
\end{aligned}
@ -538,9 +538,9 @@ is the loss of the entire computation (producing a real loss) and
:math:`s` is the output of our function. The goal here is to compute
:math:`\frac{\partial L}{\partial z^*}`, where :math:`z` is the input of
the function. It turns out that in the case of real loss, we can
get away with *only* calculating :math:`\frac{\partial L}{\partial z^*}`,
get away with *only* calculating :math:`\frac{\partial L}{\partial s^*}`,
even though the chain rule implies that we also need to
have access to :math:`\frac{\partial L}{\partial z^*}`. If you want
have access to :math:`\frac{\partial L}{\partial s}`. If you want
to skip this derivation, look at the last equation in this section
and then skip to the next section.
@ -558,8 +558,8 @@ Now using Wirtinger derivative definition, we can write:
.. math::
\begin{aligned}
\frac{\partial L}{\partial s} = 1/2 * (\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j) \\
\frac{\partial L}{\partial s^*} = 1/2 * (\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j)
\frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\
\frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right)
\end{aligned}
It should be noted here that since :math:`u` and :math:`v` are real
@ -567,7 +567,7 @@ functions, and :math:`L` is real by our assumption that :math:`f` is a
part of a real valued function, we have:
.. math::
(\frac{\partial L}{\partial s})^* = \frac{\partial L}{\partial s^*}
\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}
:label: [2]
i.e., :math:`\frac{\partial L}{\partial s}` equals to :math:`grad\_output^*`.
@ -577,7 +577,7 @@ Solving the above equations for :math:`\frac{\partial L}{\partial u}` and :math:
.. math::
\begin{aligned}
\frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\
\frac{\partial L}{\partial v} = -1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*})
\frac{\partial L}{\partial v} = -1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right)
\end{aligned}
:label: [3]
@ -585,8 +585,8 @@ Substituting :eq:`[3]` in :eq:`[1]`, we get:
.. math::
\begin{aligned}
\frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) * \frac{\partial u}{\partial z^*} - 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) * \frac{\partial v}{\partial z^*} \\
&= \frac{\partial L}{\partial s} * (\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j) + \frac{\partial L}{\partial s^*} * (\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j) \\
\frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} - 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\
&= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\
&= \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s} * \frac{\partial (u + vj)^*}{\partial z^*} \\
&= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\
\end{aligned}
@ -595,8 +595,8 @@ Using :eq:`[2]`, we get:
.. math::
\begin{aligned}
\frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s^*})^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * (\frac{\partial s}{\partial z})^* \\
&= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * {(\frac{\partial s}{\partial z})}^* } \\
\frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\
&= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * \left(\frac{\partial s}{\partial z}\right)^* } \\
\end{aligned}
:label: [4]
@ -624,12 +624,12 @@ Using the first way to compute the Wirtinger derivatives, we have.
.. math::
\begin{aligned}
\frac{\partial s}{\partial z} &= 1/2 * (\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j) \\
\frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\
&= 1/2 * (c - (c * 1j) * 1j) \\
&= c \\
\\
\\
\frac{\partial s}{\partial z^*} &= 1/2 * (\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j) \\
\frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\
&= 1/2 * (c + (c * 1j) * 1j) \\
&= 0 \\
\end{aligned}
@ -667,7 +667,7 @@ chain rule:
- For :math:`f: `, we get:
.. math::
\frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}})
\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})
.. _saved-tensors-hooks-doc: