Document complex optimizer semantic behavior (#121667)
<img width="817" alt="image" src="https://github.com/pytorch/pytorch/assets/31798555/565b389d-3e86-4767-9fcb-fe075b50aefe"> Pull Request resolved: https://github.com/pytorch/pytorch/pull/121667 Approved by: https://github.com/albanD
committed by PyTorch MergeBot
parent 12662900f9
commit 37e563276b
@@ -129,9 +129,38 @@ Autograd
 PyTorch supports autograd for complex tensors. The gradient computed is the Conjugate Wirtinger derivative,
 the negative of which is precisely the direction of steepest descent used in the Gradient Descent algorithm. Thus,
-all the existing optimizers work out of the box with complex parameters. For more details,
+all the existing optimizers can be implemented to work out of the box with complex parameters. For more details,
 check out the note :ref:`complex_autograd-doc`.
 
+Optimizers
+----------
+
+Semantically, we define stepping through a PyTorch optimizer with complex parameters as being equivalent to stepping
+through the same optimizer on the :func:`torch.view_as_real` equivalent of the complex params. More concretely:
+
+::
+
+    >>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
+    >>> real_params = [torch.view_as_real(p) for p in params]
+
+    >>> complex_optim = torch.optim.AdamW(params)
+    >>> real_optim = torch.optim.AdamW(real_params)
+
+``real_optim`` and ``complex_optim`` will compute the same updates on the parameters, though there may be slight numerical
+discrepancies between the two optimizers, similar to the numerical discrepancies between foreach vs. forloop optimizers
+and capturable vs. default optimizers. For more details, see https://pytorch.org/docs/stable/notes/numerical_accuracy.html.
+
+Specifically, while you can think of our optimizers' handling of complex tensors as being the same as optimizing over their
+``p.real`` and ``p.imag`` pieces separately, the implementation details are not precisely that. Note that the
+:func:`torch.view_as_real` equivalent will convert a complex tensor to a real tensor with shape :math:`(..., 2)`,
+whereas splitting a complex tensor into two tensors yields 2 tensors of size :math:`(...)`. This distinction has no impact on
+pointwise optimizers (like AdamW) but will cause a slight discrepancy in optimizers that do global reductions (like LBFGS).
+We currently do not have optimizers that do per-Tensor reductions and thus do not yet define this behavior. Open an issue
+if you have a use case that requires precisely defining this behavior.
+
 We do not fully support the following subsystems:
 
 * Quantization
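As a quick, illustrative check of the semantics described in the added section (editorial material, not part of this diff), the sketch below clones the parameters so the two optimizers do not share storage, hands both the same gradients (the real copies receive the :func:`torch.view_as_real` layout of the complex gradients), takes one AdamW step each, and compares the results. It also prints the shape difference between the :func:`torch.view_as_real` view and the ``p.real``/``p.imag`` split. The cloning, the manual ``.grad`` assignment, and the use of ``torch.testing.assert_close`` with default tolerances are assumptions made for this sketch, not part of the documented contract::

    import torch

    torch.manual_seed(0)

    # Complex parameters and independent real-valued copies; clone so the two
    # optimizers do not act on the same underlying storage.
    params = [torch.rand(2, 3, dtype=torch.complex64, requires_grad=True) for _ in range(5)]
    real_params = [torch.view_as_real(p.detach()).clone().requires_grad_(True) for p in params]

    # Shape difference noted above: one (..., 2) real tensor vs. two (...) tensors.
    print(torch.view_as_real(params[0]).shape)         # torch.Size([2, 3, 2])
    print(params[0].real.shape, params[0].imag.shape)  # torch.Size([2, 3]) torch.Size([2, 3])

    complex_optim = torch.optim.AdamW(params)
    real_optim = torch.optim.AdamW(real_params)

    # Hand both optimizers matching gradients: each real copy sees the
    # view_as_real layout of the corresponding complex gradient.
    for p, rp in zip(params, real_params):
        g = torch.randn_like(p)
        p.grad = g
        rp.grad = torch.view_as_real(g)

    complex_optim.step()
    real_optim.step()

    # AdamW is pointwise, so the stepped parameters should agree up to
    # floating-point noise between the two formulations.
    for p, rp in zip(params, real_params):
        torch.testing.assert_close(torch.view_as_real(p.detach()), rp.detach())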
@@ -418,8 +418,8 @@ The short version:
 the gradients are computed under the assumption that the function is a part of a larger real-valued
 loss function :math:`g(input)=L`. The gradient computed is :math:`\frac{\partial L}{\partial z^*}`
 (note the conjugation of z), the negative of which is precisely the direction of steepest descent
-used in Gradient Descent algorithm. Thus, all the existing optimizers work out of
-the box with complex parameters.
+used in the Gradient Descent algorithm. Thus, there is a viable path to making the existing optimizers
+work out of the box with complex parameters.
 - This convention matches TensorFlow's convention for complex
   differentiation, but is different from JAX (which computes
   :math:`\frac{\partial L}{\partial z}`).
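A minimal sketch of the descent claim above (editorial, not part of this diff): for a real-valued loss of a complex leaf, stepping against the autograd-computed gradient should decrease the loss. The quadratic loss and the step size of 0.1 are arbitrary choices made for the sketch::

    import torch

    # A real-valued loss of a complex parameter: L(z) = sum |z - target|^2.
    z = torch.randn(4, dtype=torch.complex64, requires_grad=True)
    target = torch.randn(4, dtype=torch.complex64)

    def loss_fn(w):
        return (w - target).abs().pow(2).sum()

    loss = loss_fn(z)
    loss.backward()

    # z.grad is the gradient discussed above; its negative is a descent
    # direction, so a small step should reduce the loss.
    with torch.no_grad():
        z_next = z - 0.1 * z.grad

    print(loss.item(), loss_fn(z_next).item())
    assert loss_fn(z_next) < loss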