[Array API] Add linalg.vecdot (#70542)

This PR adds the function `linalg.vecdot` specified by the [Array
API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot)

For the complex case, it chooses to implement \sum x_i y_i. See the
discussion in https://github.com/data-apis/array-api/issues/356

Edit. When it comes to testing, this function is not quite a binopt, nor a reduction opt. As such, we're this close to be able to get the extra testing, but we don't quite make it. Now, it's such a simple op that I think we'll make it without this.

Resolves https://github.com/pytorch/pytorch/issues/18027.

cc @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70542
Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
This commit is contained in:
lezcano
2022-07-06 11:02:33 +00:00
committed by PyTorch MergeBot
parent 6a7ed56d79
commit 74208a9c68
9 changed files with 118 additions and 6 deletions

View File

@ -3510,21 +3510,34 @@ add_docstr(torch.vdot,
r"""
vdot(input, other, *, out=None) -> Tensor
Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
differently than dot(a, b). If the first argument is complex, the complex conjugate of the
first argument is used for the calculation of the dot product.
Computes the dot product of two 1D vectors along a dimension.
In symbols, this function computes
.. math::
\sum_{i=1}^n \overline{x_i}y_i.
where :math:`\overline{x_i}` denotes the conjugate for complex
vectors, and it is the identity for real vectors.
.. note::
Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
of two 1D tensors with the same number of elements.
.. seealso::
:func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
Args:
input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
other (Tensor): second tensor in the dot product, must be 1D.
Keyword args:
{out}
""" + fr"""
.. note:: {common_args["out"]}
""" + r"""
Example::
@ -3536,7 +3549,7 @@ Example::
tensor([16.+1.j])
>>> torch.vdot(b, a)
tensor([16.-1.j])
""".format(**common_args))
""")
add_docstr(torch.eig,
r"""