mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
fix cosine similarity dimensionality check (#66191)
Summary: Fixes https://github.com/pytorch/pytorch/issues/66086 Pull Request resolved: https://github.com/pytorch/pytorch/pull/66191 Reviewed By: dagitses, malfet Differential Revision: D31436997 Pulled By: ngimel fbshipit-source-id: 363556eea4e1696d928ae08320d298451c286b10
This commit is contained in:
committed by
Facebook GitHub Bot
parent
05e1476d49
commit
4a50b6c490
@ -240,14 +240,11 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
|
||||
}
|
||||
|
||||
Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
|
||||
TORCH_CHECK(x1.ndimension() == x2.ndimension(), "cosine_similarity requires both inputs to have the same number of dimensions, but x1 has ",
|
||||
x1.ndimension(), " and x2 has ", x2.ndimension());
|
||||
TORCH_CHECK(x1.ndimension() == 0 || x1.size(dim) == x2.size(dim), "cosine_similarity requires both inputs to have the same size at dimension ", dim, "but x1 has ",
|
||||
x1.size(dim), " and x2 has ", x2.size(dim));
|
||||
auto common_size = at::infer_size_dimvector(x1.sizes(), x2.sizes());
|
||||
auto commonDtype = at::result_type(x1, x2);
|
||||
TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype);
|
||||
Tensor x1_ = x1.to(commonDtype);
|
||||
Tensor x2_ = x2.to(commonDtype);
|
||||
Tensor x1_ = x1.to(commonDtype).expand(common_size);
|
||||
Tensor x2_ = x2.to(commonDtype).expand(common_size);
|
||||
// Follow scipy impl to improve numerical precision
|
||||
// Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
|
||||
Tensor w12 = at::sum(x1_ * x2_, dim);
|
||||
|
@ -9708,12 +9708,6 @@ class TestNN(NNTestCase):
|
||||
self.assertEqual(input1.grad, torch.zeros_like(input1))
|
||||
self.assertEqual(input2.grad, input1 * 1e8)
|
||||
|
||||
# Check error when inputs are not the same shape
|
||||
input1 = torch.randn(2, 2, 1)
|
||||
input2 = torch.randn(2, 1, 3)
|
||||
with self.assertRaises(RuntimeError):
|
||||
F.cosine_similarity(input1, input2)
|
||||
|
||||
# Check type promotion, issue #61454
|
||||
input = torch.tensor(12.)
|
||||
out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
|
||||
|
@ -4304,7 +4304,10 @@ cosine_similarity = _add_docstr(
|
||||
r"""
|
||||
cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor
|
||||
|
||||
Returns cosine similarity between x1 and x2, computed along dim.
|
||||
Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable
|
||||
to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
|
||||
squeezed (see :func:`torch.squeeze`), resulting in the
|
||||
output tensor having 1 fewer dimension.
|
||||
|
||||
.. math ::
|
||||
\text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
|
||||
@ -4313,16 +4316,11 @@ Supports :ref:`type promotion <type-promotion-doc>`.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): First input.
|
||||
x2 (Tensor): Second input (with the same number of dimensions as x1, matching x1 size at dimension `dim`,
|
||||
and broadcastable with x1 at other dimensions).
|
||||
dim (int, optional): Dimension of vectors. Default: 1
|
||||
x2 (Tensor): Second input.
|
||||
dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
|
||||
eps (float, optional): Small value to avoid division by zero.
|
||||
Default: 1e-8
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
|
||||
- Output: :math:`(\ast_1, \ast_2)`
|
||||
|
||||
Example::
|
||||
|
||||
>>> input1 = torch.randn(100, 128)
|
||||
|
@ -1295,6 +1295,8 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa
|
||||
yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
|
||||
# Test for Broadcasting
|
||||
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
|
||||
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
|
||||
yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
|
||||
|
||||
return list(generator())
|
||||
|
||||
|
Reference in New Issue
Block a user