Support numpy array in Tensor.__eq__ (#122249)

When the `other` arg of `Tensor.__eq__` is a numpy array, it is now converted to a PyTorch tensor that views the numpy array's data, and that view is then passed as the `other` arg to a `Tensor.eq` call.
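For example, a minimal sketch of the resulting behavior (values are arbitrary; printed output is illustrative):

```python
import numpy as np
import torch

t = torch.tensor([1, 2, 3])
a = np.array([1, 0, 3])

# The numpy array is wrapped in a zero-copy tensor view and
# compared elementwise via Tensor.eq:
print(t == a)                     # tensor([ True, False,  True])
print(t.eq(torch.from_numpy(a)))  # same result
```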

Fixes #119965
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122249
Approved by: https://github.com/ezyang
Author: Kurt Mohler
Date: 2024-03-21 01:58:06 +00:00
Committed by: PyTorch MergeBot
Parent: bf18e967b4
Commit: b915877deb

2 changed files with 81 additions and 1 deletion


@@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import \
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
 from torch.testing._internal.common_dtype import all_types_and_complex_and
+from torch.testing import make_tensor
 
 
 # For testing handling NumPy objects and sending tensors to / accepting
@@ -497,6 +498,49 @@ class TestNumPyInterop(TestCase):
             else:
                 self.assertTrue(t == a)
 
+    @onlyCPU
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
+    def test___eq__(self, device, dtype):
+        a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9)
+        b = a.clone().detach()
+        b_np = b.numpy()
+
+        # Check all elements equal
+        res_check = torch.ones_like(a, dtype=torch.bool)
+        self.assertEqual(a == b_np, res_check)
+        self.assertEqual(b_np == a, res_check)
+
+        # Check one element unequal
+        if dtype == torch.bool:
+            b[1][3] = not b[1][3]
+        else:
+            b[1][3] += 1
+        res_check[1][3] = False
+        self.assertEqual(a == b_np, res_check)
+        self.assertEqual(b_np == a, res_check)
+
+        # Check random elements unequal
+        rand = torch.randint(0, 2, a.shape, dtype=torch.bool)
+        res_check = rand.logical_not()
+        b.copy_(a)
+        if dtype == torch.bool:
+            b[rand] = b[rand].logical_not()
+        else:
+            b[rand] += 1
+        self.assertEqual(a == b_np, res_check)
+        self.assertEqual(b_np == a, res_check)
+
+        # Check all elements unequal
+        if dtype == torch.bool:
+            b.copy_(a.logical_not())
+        else:
+            b.copy_(a + 1)
+        res_check.fill_(False)
+        self.assertEqual(a == b_np, res_check)
+        self.assertEqual(b_np == a, res_check)
+
     @onlyCPU
     def test_empty_tensors_interop(self, device):
         x = torch.rand((), dtype=torch.float16)


@@ -22,6 +22,7 @@
 #include "torch/csrc/cuda/Event.h"
 #endif
 #include "torch/csrc/utils/device_lazy_init.h"
+#include <torch/csrc/utils/numpy_stub.h>
 #include "torch/csrc/utils/object_ptr.h"
 #include "torch/csrc/utils/pycfunction_helpers.h"
 #include "torch/csrc/utils/python_arg_parser.h"
@@ -1078,6 +1079,41 @@ static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) {
   return THPVariable_is_nonzero(self, args);
 }
 
+static PyObject * THPVariable___eq__(PyObject* self_, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+#ifdef USE_NUMPY
+  if (torch::utils::is_numpy_available()) {
+    static PythonArgParser parser({
+      "__eq__(PyObject* other)",
+    }, /*traceable=*/true);
+    ParsedArgs<1> parsed_args;
+    auto _r = parser.parse(self_, args, kwargs, parsed_args);
+    if (_r.has_torch_function()) {
+      return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor");
+    }
+    switch (_r.idx) {
+      case 0: {
+        auto other = _r.pyobject(0);
+        if (PyArray_Check(other)) {
+          auto other_tensor = torch::utils::tensor_from_numpy(other);
+          auto dispatch_eq = [](const at::Tensor & self, const at::Tensor & other) -> at::Tensor {
+            pybind11::gil_scoped_release no_gil;
+            return self.eq(other);
+          };
+          const Tensor& self = THPVariable_Unpack(self_);
+          return wrap(dispatch_eq(self, other_tensor));
+        }
+      }
+    }
+  }
+#endif
+  return THPVariable_eq(self_, args, kwargs);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 // Wrapper converts a raised TypeError into returning NotImplemented
 // Used to implement binary arithmetic operators
 template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)>
@@ -1209,7 +1245,7 @@ PyMethodDef variable_methods[] = {
   {"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_floor_divide_>), METH_VARARGS | METH_KEYWORDS, NULL},
   {"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder>), METH_VARARGS | METH_KEYWORDS, NULL},
   {"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder_>), METH_VARARGS | METH_KEYWORDS, NULL},
-  {"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_eq>), METH_VARARGS | METH_KEYWORDS, NULL},
+  {"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable___eq__>), METH_VARARGS | METH_KEYWORDS, NULL},
   {"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_ne>), METH_VARARGS | METH_KEYWORDS, NULL},
   {"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_lt>), METH_VARARGS | METH_KEYWORDS, NULL},
   {"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_le>), METH_VARARGS | METH_KEYWORDS, NULL},