From 4a50b6c4905acc9c00597e7ff69a74591a8ef13e Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 6 Oct 2021 15:41:15 -0700 Subject: [PATCH] fix cosine similarity dimensionality check (#66191) Summary: Fixes https://github.com/pytorch/pytorch/issues/66086 Pull Request resolved: https://github.com/pytorch/pytorch/pull/66191 Reviewed By: dagitses, malfet Differential Revision: D31436997 Pulled By: ngimel fbshipit-source-id: 363556eea4e1696d928ae08320d298451c286b10 --- aten/src/ATen/native/Distance.cpp | 9 +++------ test/test_nn.py | 6 ------ torch/nn/functional.py | 14 ++++++-------- .../_internal/common_methods_invocations.py | 2 ++ 4 files changed, 11 insertions(+), 20 deletions(-) diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index 7974840dd3f9..9105c831e993 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -240,14 +240,11 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c } Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) { - TORCH_CHECK(x1.ndimension() == x2.ndimension(), "cosine_similarity requires both inputs to have the same number of dimensions, but x1 has ", - x1.ndimension(), " and x2 has ", x2.ndimension()); - TORCH_CHECK(x1.ndimension() == 0 || x1.size(dim) == x2.size(dim), "cosine_similarity requires both inputs to have the same size at dimension ", dim, "but x1 has ", - x1.size(dim), " and x2 has ", x2.size(dim)); + auto common_size = at::infer_size_dimvector(x1.sizes(), x2.sizes()); auto commonDtype = at::result_type(x1, x2); TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype); - Tensor x1_ = x1.to(commonDtype); - Tensor x2_ = x2.to(commonDtype); + Tensor x1_ = x1.to(commonDtype).expand(common_size); + Tensor x2_ = x2.to(commonDtype).expand(common_size); // Follow scipy impl to improve numerical precision // Use x / sqrt(x * 
x) instead of x / (sqrt(x) * sqrt(x)) Tensor w12 = at::sum(x1_ * x2_, dim); diff --git a/test/test_nn.py b/test/test_nn.py index f8acfda22c05..8bf6ace9e196 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9708,12 +9708,6 @@ class TestNN(NNTestCase): self.assertEqual(input1.grad, torch.zeros_like(input1)) self.assertEqual(input2.grad, input1 * 1e8) - # Check error when inputs are not the same shape - input1 = torch.randn(2, 2, 1) - input2 = torch.randn(2, 1, 3) - with self.assertRaises(RuntimeError): - F.cosine_similarity(input1, input2) - # Check type promotion, issue #61454 input = torch.tensor(12.) out = F.cosine_similarity(input.to(torch.int8), input, dim=-1) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index e2866904514a..29e0c9fd8082 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -4304,7 +4304,10 @@ cosine_similarity = _add_docstr( r""" cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor -Returns cosine similarity between x1 and x2, computed along dim. +Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable +to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is +squeezed (see :func:`torch.squeeze`), resulting in the +output tensor having 1 fewer dimension. .. math :: \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)} @@ -4313,16 +4316,11 @@ Supports :ref:`type promotion <type-promotion-doc>`. Args: x1 (Tensor): First input. - x2 (Tensor): Second input (with the same number of dimensions as x1, matching x1 size at dimension `dim`, - and broadcastable with x1 at other dimensions). - dim (int, optional): Dimension of vectors. Default: 1 + x2 (Tensor): Second input. + dim (int, optional): Dimension along which cosine similarity is computed. Default: 1 eps (float, optional): Small value to avoid division by zero. 
Default: 1e-8 -Shape: - - Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`. - - Output: :math:`(\ast_1, \ast_2)` - Example:: >>> input1 = torch.randn(100, 128) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 09a2813d4aaa..f5fd87ac1380 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1295,6 +1295,8 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs) # Test for Broadcasting yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2}) + yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) return list(generator())