mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support numpy array in Tensor.__eq__
(#122249)
When the `other` arg of `Tensor.__eq__` is a numpy array, it is converted to a PyTorch tensor view of the numpy array, which is then given as the `other` arg to a `Tensor.eq` call Fixes #119965 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122249 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
bf18e967b4
commit
b915877deb
@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import \
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
from torch.testing import make_tensor
|
||||
|
||||
|
||||
# For testing handling NumPy objects and sending tensors to / accepting
|
||||
@ -497,6 +498,49 @@ class TestNumPyInterop(TestCase):
|
||||
else:
|
||||
self.assertTrue(t == a)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
def test___eq__(self, device, dtype):
|
||||
a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9)
|
||||
b = a.clone().detach()
|
||||
b_np = b.numpy()
|
||||
|
||||
# Check all elements equal
|
||||
res_check = torch.ones_like(a, dtype=torch.bool)
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
# Check one element unequal
|
||||
if dtype == torch.bool:
|
||||
b[1][3] = not b[1][3]
|
||||
else:
|
||||
b[1][3] += 1
|
||||
res_check[1][3] = False
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
# Check random elements unequal
|
||||
rand = torch.randint(0, 2, a.shape, dtype=torch.bool)
|
||||
res_check = rand.logical_not()
|
||||
b.copy_(a)
|
||||
|
||||
if dtype == torch.bool:
|
||||
b[rand] = b[rand].logical_not()
|
||||
else:
|
||||
b[rand] += 1
|
||||
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
# Check all elements unequal
|
||||
if dtype == torch.bool:
|
||||
b.copy_(a.logical_not())
|
||||
else:
|
||||
b.copy_(a + 1)
|
||||
res_check.fill_(False)
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
@onlyCPU
|
||||
def test_empty_tensors_interop(self, device):
|
||||
x = torch.rand((), dtype=torch.float16)
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "torch/csrc/cuda/Event.h"
|
||||
#endif
|
||||
#include "torch/csrc/utils/device_lazy_init.h"
|
||||
#include <torch/csrc/utils/numpy_stub.h>
|
||||
#include "torch/csrc/utils/object_ptr.h"
|
||||
#include "torch/csrc/utils/pycfunction_helpers.h"
|
||||
#include "torch/csrc/utils/python_arg_parser.h"
|
||||
@ -1078,6 +1079,41 @@ static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) {
|
||||
return THPVariable_is_nonzero(self, args);
|
||||
}
|
||||
|
||||
static PyObject * THPVariable___eq__(PyObject* self_, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#ifdef USE_NUMPY
|
||||
if (torch::utils::is_numpy_available()) {
|
||||
static PythonArgParser parser({
|
||||
"__eq__(PyObject* other)",
|
||||
}, /*traceable=*/true);
|
||||
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto _r = parser.parse(self_, args, kwargs, parsed_args);
|
||||
if(_r.has_torch_function()) {
|
||||
return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor");
|
||||
}
|
||||
switch (_r.idx) {
|
||||
case 0: {
|
||||
auto other = _r.pyobject(0);
|
||||
if (PyArray_Check(other)) {
|
||||
auto other_tensor = torch::utils::tensor_from_numpy(other);
|
||||
auto dispatch_eq = [](const at::Tensor & self, const at::Tensor & other) -> at::Tensor {
|
||||
pybind11::gil_scoped_release no_gil;
|
||||
return self.eq(other);
|
||||
};
|
||||
const Tensor& self = THPVariable_Unpack(self_);
|
||||
return wrap(dispatch_eq(self, other_tensor));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return THPVariable_eq(self_, args, kwargs);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// Wrapper converts a raised TypeError into returning NotImplemented
|
||||
// Used to implement binary arithmetic operators
|
||||
template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)>
|
||||
@ -1209,7 +1245,7 @@ PyMethodDef variable_methods[] = {
|
||||
{"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_floor_divide_>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder_>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_eq>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable___eq__>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_ne>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_lt>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_le>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
|
Reference in New Issue
Block a user