mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support numpy array in Tensor.__eq__
(#122249)
When the `other` arg of `Tensor.__eq__` is a numpy array, it is converted to a PyTorch tensor view of the numpy array, which is then given as the `other` arg to a `Tensor.eq` call Fixes #119965 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122249 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
bf18e967b4
commit
b915877deb
@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import \
|
|||||||
from torch.testing._internal.common_device_type import \
|
from torch.testing._internal.common_device_type import \
|
||||||
(instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
|
(instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
|
||||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||||
|
from torch.testing import make_tensor
|
||||||
|
|
||||||
|
|
||||||
# For testing handling NumPy objects and sending tensors to / accepting
|
# For testing handling NumPy objects and sending tensors to / accepting
|
||||||
@ -497,6 +498,49 @@ class TestNumPyInterop(TestCase):
|
|||||||
else:
|
else:
|
||||||
self.assertTrue(t == a)
|
self.assertTrue(t == a)
|
||||||
|
|
||||||
|
@onlyCPU
|
||||||
|
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||||
|
def test___eq__(self, device, dtype):
|
||||||
|
a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9)
|
||||||
|
b = a.clone().detach()
|
||||||
|
b_np = b.numpy()
|
||||||
|
|
||||||
|
# Check all elements equal
|
||||||
|
res_check = torch.ones_like(a, dtype=torch.bool)
|
||||||
|
self.assertEqual(a == b_np, res_check)
|
||||||
|
self.assertEqual(b_np == a, res_check)
|
||||||
|
|
||||||
|
# Check one element unequal
|
||||||
|
if dtype == torch.bool:
|
||||||
|
b[1][3] = not b[1][3]
|
||||||
|
else:
|
||||||
|
b[1][3] += 1
|
||||||
|
res_check[1][3] = False
|
||||||
|
self.assertEqual(a == b_np, res_check)
|
||||||
|
self.assertEqual(b_np == a, res_check)
|
||||||
|
|
||||||
|
# Check random elements unequal
|
||||||
|
rand = torch.randint(0, 2, a.shape, dtype=torch.bool)
|
||||||
|
res_check = rand.logical_not()
|
||||||
|
b.copy_(a)
|
||||||
|
|
||||||
|
if dtype == torch.bool:
|
||||||
|
b[rand] = b[rand].logical_not()
|
||||||
|
else:
|
||||||
|
b[rand] += 1
|
||||||
|
|
||||||
|
self.assertEqual(a == b_np, res_check)
|
||||||
|
self.assertEqual(b_np == a, res_check)
|
||||||
|
|
||||||
|
# Check all elements unequal
|
||||||
|
if dtype == torch.bool:
|
||||||
|
b.copy_(a.logical_not())
|
||||||
|
else:
|
||||||
|
b.copy_(a + 1)
|
||||||
|
res_check.fill_(False)
|
||||||
|
self.assertEqual(a == b_np, res_check)
|
||||||
|
self.assertEqual(b_np == a, res_check)
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
def test_empty_tensors_interop(self, device):
|
def test_empty_tensors_interop(self, device):
|
||||||
x = torch.rand((), dtype=torch.float16)
|
x = torch.rand((), dtype=torch.float16)
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
#include "torch/csrc/cuda/Event.h"
|
#include "torch/csrc/cuda/Event.h"
|
||||||
#endif
|
#endif
|
||||||
#include "torch/csrc/utils/device_lazy_init.h"
|
#include "torch/csrc/utils/device_lazy_init.h"
|
||||||
|
#include <torch/csrc/utils/numpy_stub.h>
|
||||||
#include "torch/csrc/utils/object_ptr.h"
|
#include "torch/csrc/utils/object_ptr.h"
|
||||||
#include "torch/csrc/utils/pycfunction_helpers.h"
|
#include "torch/csrc/utils/pycfunction_helpers.h"
|
||||||
#include "torch/csrc/utils/python_arg_parser.h"
|
#include "torch/csrc/utils/python_arg_parser.h"
|
||||||
@ -1078,6 +1079,41 @@ static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) {
|
|||||||
return THPVariable_is_nonzero(self, args);
|
return THPVariable_is_nonzero(self, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject * THPVariable___eq__(PyObject* self_, PyObject* args, PyObject* kwargs)
|
||||||
|
{
|
||||||
|
HANDLE_TH_ERRORS
|
||||||
|
#ifdef USE_NUMPY
|
||||||
|
if (torch::utils::is_numpy_available()) {
|
||||||
|
static PythonArgParser parser({
|
||||||
|
"__eq__(PyObject* other)",
|
||||||
|
}, /*traceable=*/true);
|
||||||
|
|
||||||
|
ParsedArgs<1> parsed_args;
|
||||||
|
auto _r = parser.parse(self_, args, kwargs, parsed_args);
|
||||||
|
if(_r.has_torch_function()) {
|
||||||
|
return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor");
|
||||||
|
}
|
||||||
|
switch (_r.idx) {
|
||||||
|
case 0: {
|
||||||
|
auto other = _r.pyobject(0);
|
||||||
|
if (PyArray_Check(other)) {
|
||||||
|
auto other_tensor = torch::utils::tensor_from_numpy(other);
|
||||||
|
auto dispatch_eq = [](const at::Tensor & self, const at::Tensor & other) -> at::Tensor {
|
||||||
|
pybind11::gil_scoped_release no_gil;
|
||||||
|
return self.eq(other);
|
||||||
|
};
|
||||||
|
const Tensor& self = THPVariable_Unpack(self_);
|
||||||
|
return wrap(dispatch_eq(self, other_tensor));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return THPVariable_eq(self_, args, kwargs);
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
END_HANDLE_TH_ERRORS
|
||||||
|
}
|
||||||
|
|
||||||
// Wrapper converts a raised TypeError into returning NotImplemented
|
// Wrapper converts a raised TypeError into returning NotImplemented
|
||||||
// Used to implement binary arithmetic operators
|
// Used to implement binary arithmetic operators
|
||||||
template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)>
|
template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)>
|
||||||
@ -1209,7 +1245,7 @@ PyMethodDef variable_methods[] = {
|
|||||||
{"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_floor_divide_>), METH_VARARGS | METH_KEYWORDS, NULL},
|
{"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_floor_divide_>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||||
{"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder>), METH_VARARGS | METH_KEYWORDS, NULL},
|
{"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||||
{"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder_>), METH_VARARGS | METH_KEYWORDS, NULL},
|
{"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_remainder_>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||||
{"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_eq>), METH_VARARGS | METH_KEYWORDS, NULL},
|
{"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable___eq__>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||||
{"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_ne>), METH_VARARGS | METH_KEYWORDS, NULL},
|
{"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_ne>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||||
{"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_lt>), METH_VARARGS | METH_KEYWORDS, NULL},
|
{"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_lt>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||||
{"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_le>), METH_VARARGS | METH_KEYWORDS, NULL},
|
{"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_<THPVariable_le>), METH_VARARGS | METH_KEYWORDS, NULL},
|
||||||
|
Reference in New Issue
Block a user