mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[Array API] Add linalg.vecdot (#70542)
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot) For the complex case, it chooses to implement \sum x_i y_i. See the discussion in https://github.com/data-apis/array-api/issues/356 Edit. When it comes to testing, this function is not quite a binopt, nor a reduction opt. As such, we're this close to be able to get the extra testing, but we don't quite make it. Now, it's such a simple op that I think we'll make it without this. Resolves https://github.com/pytorch/pytorch/issues/18027. cc @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi Pull Request resolved: https://github.com/pytorch/pytorch/pull/70542 Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
6a7ed56d79
commit
74208a9c68
@ -3510,21 +3510,34 @@ add_docstr(torch.vdot,
|
||||
r"""
|
||||
vdot(input, other, *, out=None) -> Tensor
|
||||
|
||||
Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
|
||||
differently than dot(a, b). If the first argument is complex, the complex conjugate of the
|
||||
first argument is used for the calculation of the dot product.
|
||||
Computes the dot product of two 1D vectors along a dimension.
|
||||
|
||||
In symbols, this function computes
|
||||
|
||||
.. math::
|
||||
|
||||
\sum_{i=1}^n \overline{x_i}y_i.
|
||||
|
||||
where :math:`\overline{x_i}` denotes the conjugate for complex
|
||||
vectors, and it is the identity for real vectors.
|
||||
|
||||
.. note::
|
||||
|
||||
Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
|
||||
of two 1D tensors with the same number of elements.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
|
||||
|
||||
Args:
|
||||
input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
|
||||
other (Tensor): second tensor in the dot product, must be 1D.
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
""" + fr"""
|
||||
.. note:: {common_args["out"]}
|
||||
""" + r"""
|
||||
|
||||
Example::
|
||||
|
||||
@ -3536,7 +3549,7 @@ Example::
|
||||
tensor([16.+1.j])
|
||||
>>> torch.vdot(b, a)
|
||||
tensor([16.-1.j])
|
||||
""".format(**common_args))
|
||||
""")
|
||||
|
||||
add_docstr(torch.eig,
|
||||
r"""
|
||||
|
Reference in New Issue
Block a user