Document complex optimizer semantic behavior (#121667)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121667
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 12662900f9
Commit: 37e563276b
@@ -129,9 +129,38 @@ Autograd
PyTorch supports autograd for complex tensors. The gradient computed is the Conjugate Wirtinger derivative,
the negative of which is precisely the direction of steepest descent used in the Gradient Descent algorithm. Thus,
-all the existing optimizers work out of the box with complex parameters. For more details,
+all the existing optimizers can be implemented to work out of the box with complex parameters. For more details,
check out the note :ref:`complex_autograd-doc`.
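
For instance, here is a minimal sketch of the basic workflow, assuming an arbitrary real-valued
loss :math:`\sum_i |z_i|^2`, an arbitrary learning rate, and plain SGD:

::

    >>> import torch
    >>> z = torch.randn(3, dtype=torch.complex64, requires_grad=True)
    >>> optim = torch.optim.SGD([z], lr=0.1)
    >>> loss = z.abs().pow(2).sum()   # a real-valued loss of a complex parameter
    >>> loss.backward()               # z.grad is a complex tensor
    >>> optim.step()                  # a gradient descent step on z
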
Optimizers
----------

Semantically, we define stepping through a PyTorch optimizer with complex parameters as being equivalent to stepping
through the same optimizer on the :func:`torch.view_as_real` equivalent of the complex params. More concretely:
::

    >>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
    >>> real_params = [torch.view_as_real(p) for p in params]

    >>> complex_optim = torch.optim.AdamW(params)
    >>> real_optim = torch.optim.AdamW(real_params)

`real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical
discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers
and capturable vs default optimizers. For more details, see https://pytorch.org/docs/stable/notes/numerical_accuracy.html.
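
As a rough sketch of that equivalence (the gradients below are made up, and the real views are
cloned here, unlike above, so that the two optimizers do not share storage):

::

    >>> import torch
    >>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
    >>> real_params = [torch.view_as_real(p).clone() for p in params]
    >>> complex_optim = torch.optim.AdamW(params)
    >>> real_optim = torch.optim.AdamW(real_params)
    >>> for p, rp in zip(params, real_params):
    ...     g = torch.rand_like(p)                   # stand-in gradient
    ...     p.grad = g
    ...     rp.grad = torch.view_as_real(g).clone()
    >>> complex_optim.step()
    >>> real_optim.step()
    >>> # expected to agree up to small numerical differences
    >>> all(torch.allclose(torch.view_as_real(p), rp) for p, rp in zip(params, real_params))
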
Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their
`p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the
:func:`torch.view_as_real` equivalent will convert a complex tensor to a real tensor with shape :math:`(..., 2)`,
whereas splitting a complex tensor into its real and imaginary parts yields two tensors of size :math:`(...)`. This distinction has no impact on
pointwise optimizers (like AdamW) but will cause slight discrepancies in optimizers that do global reductions (like LBFGS).
We currently do not have optimizers that do per-Tensor reductions and thus do not yet define this behavior. Open an issue
if you have a use case that requires precisely defining this behavior.
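
To spell out the shape distinction above (a small sketch with an arbitrary tensor size):

::

    >>> import torch
    >>> p = torch.rand(2, 3, dtype=torch.complex64)
    >>> torch.view_as_real(p).shape            # one real tensor of shape (..., 2)
    torch.Size([2, 3, 2])
    >>> p.real.shape, p.imag.shape             # two real tensors of shape (...)
    (torch.Size([2, 3]), torch.Size([2, 3]))
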
We do not fully support the following subsystems:

* Quantization
@@ -418,8 +418,8 @@ The short version:
the gradients are computed under the assumption that the function is a part of a larger real-valued
loss function :math:`g(input)=L`. The gradient computed is :math:`\frac{\partial L}{\partial z^*}`
(note the conjugation of z), the negative of which is precisely the direction of steepest descent
-used in Gradient Descent algorithm. Thus, all the existing optimizers work out of
-the box with complex parameters.
+used in Gradient Descent algorithm. Thus, there is a viable path in making the existing optimizers
+work out of the box with complex parameters.
- This convention matches TensorFlow's convention for complex
  differentiation, but is different from JAX (which computes
  :math:`\frac{\partial L}{\partial z}`).
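
To make the steepest descent point above concrete, here is a minimal sketch of hand-rolled gradient
descent on the real-valued loss :math:`|z|^2` (the starting value, step size, and iteration count
are arbitrary):

::

    >>> import torch
    >>> z = torch.tensor([1.0 + 2.0j], requires_grad=True)
    >>> for _ in range(100):
    ...     loss = (z * z.conj()).real.sum()   # |z|^2, a real-valued loss
    ...     loss.backward()
    ...     with torch.no_grad():
    ...         z -= 0.1 * z.grad              # step along the negative of the reported gradient
    ...     z.grad = None
    >>> z.abs()                                # close to zero after the updates
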