Document complex optimizer semantic behavior (#121667)

<img width="817" alt="image" src="https://github.com/pytorch/pytorch/assets/31798555/565b389d-3e86-4767-9fcb-fe075b50aefe">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121667
Approved by: https://github.com/albanD
Jane Xu
2024-03-15 08:30:54 -07:00
committed by PyTorch MergeBot
parent 12662900f9
commit 37e563276b
2 changed files with 32 additions and 3 deletions


@@ -129,9 +129,38 @@ Autograd
PyTorch supports autograd for complex tensors. The gradient computed is the Conjugate Wirtinger derivative,
the negative of which is precisely the direction of steepest descent used in Gradient Descent algorithm. Thus,
-all the existing optimizers work out of the box with complex parameters. For more details,
+all the existing optimizers can be implemented to work out of the box with complex parameters. For more details,
check out the note :ref:`complex_autograd-doc`.
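For instance, here is a minimal sketch (the loss, shape, and learning rate are arbitrary choices for illustration) showing that
calling `backward()` on a real-valued loss of a complex parameter fills a complex `.grad` that an existing optimizer such as
:class:`torch.optim.SGD` consumes directly:
::

    import torch

    # A complex leaf parameter and a real-valued loss L = sum(|z|^2).
    z = torch.randn(3, dtype=torch.complex64, requires_grad=True)
    loss = (z * z.conj()).real.sum()
    loss.backward()

    print(z.grad.dtype)               # torch.complex64: the gradient is itself complex

    # An existing optimizer steps on the complex parameter as-is.
    opt = torch.optim.SGD([z], lr=0.1)
    opt.step()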
Optimizers
----------
Semantically, we define stepping through a PyTorch optimizer with complex parameters as being equivalent to stepping
through the same optimizer on the :func:`torch.view_as_real` equivalent of the complex params. More concretely:
::
>>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
>>> real_params = [torch.view_as_real(p) for p in params]
>>> complex_optim = torch.optim.AdamW(params)
>>> real_optim = torch.optim.AdamW(real_params)
`real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical
discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers
and capturable vs default optimizers. For more details, see https://pytorch.org/docs/stable/notes/numerical_accuracy.html.
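To make this concrete, below is a minimal runnable sketch. It uses independent copies rather than :func:`torch.view_as_real`
views (so both optimizers can be stepped separately) and a made-up gradient; with a pointwise optimizer like AdamW the two
parameterizations are expected to stay in sync up to small numerical differences:
::

    import torch

    torch.manual_seed(0)
    p_complex = torch.randn(2, 3, dtype=torch.complex64)
    p_real = torch.view_as_real(p_complex).clone()        # independent copy, shape (2, 3, 2)

    # A made-up gradient, purely for illustration.
    g_complex = torch.randn(2, 3, dtype=torch.complex64)
    p_complex.grad = g_complex
    p_real.grad = torch.view_as_real(g_complex).clone()

    complex_optim = torch.optim.AdamW([p_complex], lr=1e-2)
    real_optim = torch.optim.AdamW([p_real], lr=1e-2)
    complex_optim.step()
    real_optim.step()

    # The complex parameter, viewed as real, matches the update on the real parameter.
    print(torch.allclose(torch.view_as_real(p_complex), p_real))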
Specifically, while you can think of our optimizers' handling of complex tensors as the same as optimizing over their
`p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the
:func:`torch.view_as_real` equivalent will convert a complex tensor to a real tensor with shape :math:`(..., 2)`,
whereas splitting a complex tensor into its real and imaginary parts yields two tensors of size :math:`(...)`. This distinction has no impact on
pointwise optimizers (like AdamW) but will cause a slight discrepancy in optimizers that do global reductions (like LBFGS).
We currently do not have optimizers that do per-Tensor reductions and thus do not yet define this behavior. Open an issue
if you have a use case that requires precisely defining this behavior.
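The shape distinction is easy to see directly; a small sketch with an arbitrary :math:`2 \times 3` parameter:
::

    import torch

    z = torch.randn(2, 3, dtype=torch.complex64)

    # The view_as_real equivalent: one real tensor with a trailing dimension of 2.
    print(torch.view_as_real(z).shape)        # torch.Size([2, 3, 2])

    # Splitting instead gives two real tensors of the original shape.
    print(z.real.shape, z.imag.shape)         # torch.Size([2, 3]) torch.Size([2, 3])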
We do not fully support the following subsystems:
* Quantization


@@ -418,8 +418,8 @@ The short version:
the gradients are computed under the assumption that the function is a part of a larger real-valued
loss function :math:`g(input)=L`. The gradient computed is :math:`\frac{\partial L}{\partial z^*}`
(note the conjugation of z), the negative of which is precisely the direction of steepest descent
-used in Gradient Descent algorithm. Thus, all the existing optimizers work out of
-the box with complex parameters.
+used in Gradient Descent algorithm. Thus, there is a viable path in making the existing optimizers
+work out of the box with complex parameters.
- This convention matches TensorFlow's convention for complex
differentiation, but is different from JAX (which computes
:math:`\frac{\partial L}{\partial z}`).
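As a simplified sketch of the convention described above (the particular loss and step size are arbitrary choices), the
negative of the computed gradient of a real-valued loss with respect to a complex input is a descent direction, so a plain
step `z - lr * z.grad` decreases the loss:
::

    import torch

    z = torch.randn(4, dtype=torch.complex64, requires_grad=True)
    loss = (z * z.conj()).real.sum()       # L = sum(|z|^2), a real-valued loss
    loss.backward()

    with torch.no_grad():
        z_next = z - 0.1 * z.grad          # step against the computed gradient
    new_loss = (z_next * z_next.conj()).real.sum()

    print(bool(new_loss < loss))           # True: the loss decreased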