[test/torch_np] Fix usages of deprecated NumPy 2.0 APIs in numpy_tests (#131909)

Migrates usages of APIs that are deprecated or removed in NumPy 2.0, per the [NumPy 2.0 migration guide](https://numpy.org/devdocs/numpy_2_0_migration_guide.html#numpy-2-0-migration-guide).

I grepped for the old API usages (see list below); they are only referenced in test files under `test/torch_np/numpy_tests/**/*.py`.

Specifically, migrates the usages of the following APIs (a short before/after sketch follows the list):

1. `np.sctypes` → Access dtypes explicitly instead
2. `np.float_` → `np.float64`
3. `np.complex_` → `np.complex128`
4. `np.longcomplex` → `np.clongdouble`
5. `np.unicode_` → `np.str_`
6. `np.product` → `np.prod`
7. `np.cumproduct` → `np.cumprod`
8. `np.alltrue` → `np.all`
9. `np.sometrue` → `np.any`
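
A minimal before/after sketch of these replacements (illustrative only, not taken from the diff; `arr` is a placeholder array):

```python
import numpy as np

arr = np.array([[1.0, 2.0], [3.0, 4.0]])  # placeholder data

x = np.float64(3.5)       # was np.float_(3.5)
z = np.complex128(1j)     # was np.complex_(1j)
c = np.clongdouble(1j)    # was np.longcomplex(1j)
s = np.str_("6")          # was np.unicode_("6")
p = np.prod(arr)          # was np.product(arr)
cp = np.cumprod(arr)      # was np.cumproduct(arr)
a = np.all(arr > 0)       # was np.alltrue(arr > 0)
n = np.any(arr > 0)       # was np.sometrue(arr > 0)

# np.sctypes is gone in NumPy 2.0; spell out the dtypes you need instead, e.g.:
float_types = [np.float16, np.float32, np.float64, np.longdouble]
```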

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131909
Approved by: https://github.com/rgommers, https://github.com/Skylion007, https://github.com/atalman
Author:    Kiuk Chung
Date:      2024-08-05 16:21:08 +00:00
Committer: PyTorch MergeBot
Parent:    a672f6c84e
Commit:    d532c00c81

9 changed files with 67 additions and 35 deletions

View File

@@ -215,7 +215,15 @@ class TestMisc(TestCase):
     @xpassIfTorchDynamo  # (reason="None of nmant, minexp, maxexp is implemented.")
     def test_plausible_finfo(self):
         # Assert that finfo returns reasonable results for all types
-        for ftype in np.sctypes["float"] + np.sctypes["complex"]:
+        for ftype in (
+            [np.float16, np.float32, np.float64, np.longdouble]
+            + [
+                np.complex64,
+                np.complex128,
+            ]
+            # no complex256 in torch._numpy
+            + ([np.clongdouble] if hasattr(np, "clongdouble") else [])
+        ):
             info = np.finfo(ftype)
             assert_(info.nmant > 1)
             assert_(info.minexp < -1)

View File

@@ -880,7 +880,7 @@ class TestMultiIndexingAutomated(TestCase):
             if np.any(_indx >= _size) or np.any(_indx < -_size):
                 raise IndexError
         if len(indx[1:]) == len(orig_slice):
-            if np.product(orig_slice) == 0:
+            if np.prod(orig_slice) == 0:
                 # Work around for a crash or IndexError with 'wrap'
                 # in some 0-sized cases.
                 try:
@@ -1092,7 +1092,7 @@ class TestFloatNonIntegerArgument(TestCase):
         def mult(a, b):
             return a * b

-        assert_raises(TypeError, mult, [1], np.float_(3))
+        assert_raises(TypeError, mult, [1], np.float64(3))
         # following should be OK
         mult([1], np.int_(3))

View File

@@ -373,7 +373,7 @@ class TestAttributes(TestCase):
     def test_dtypeattr(self):
         assert_equal(self.one.dtype, np.dtype(np.int_))
-        assert_equal(self.three.dtype, np.dtype(np.float_))
+        assert_equal(self.three.dtype, np.dtype(np.float64))
         assert_equal(self.one.dtype.char, "l")
         assert_equal(self.three.dtype.char, "d")
         assert_(self.three.dtype.str[0] in "<>")
@@ -690,12 +690,15 @@ class TestAssignment(TestCase):
         assert_raises(ValueError, operator.setitem, u, 0, bad_sequence())
         assert_raises(ValueError, operator.setitem, b, 0, bad_sequence())

-    @skip(reason="longdouble")
+    @skipif(
+        "torch._numpy" == np.__name__,
+        reason="torch._numpy does not support extended floats and complex dtypes",
+    )
     def test_longdouble_assignment(self):
         # only relevant if longdouble is larger than float
         # we're looking for loss of precision

-        for dtype in (np.longdouble, np.longcomplex):
+        for dtype in (np.longdouble, np.clongdouble):
            # gh-8902
            tinyb = np.nextafter(np.longdouble(0), 1).astype(dtype)
            tinya = np.nextafter(np.longdouble(0), -1).astype(dtype)
@@ -1396,7 +1399,7 @@ class TestBool(TestCase):
     @xfail  # (reason="See gh-9847")
     def test_cast_from_unicode(self):
-        self._test_cast_from_flexible(np.unicode_)
+        self._test_cast_from_flexible(np.str_)

     @xfail  # (reason="See gh-9847")
     def test_cast_from_bytes(self):
@@ -1827,7 +1830,7 @@ class TestMethods(TestCase):
         a = np.array(["aaaaaaaaa" for i in range(100)])
         assert_equal(a.argsort(kind="m"), r)
         # unicode
-        a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.unicode_)
+        a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.str_)
         assert_equal(a.argsort(kind="m"), r)

     @xpassIfTorchDynamo  # (reason="TODO: searchsorted with nans differs in pytorch")
@@ -3486,6 +3489,16 @@ class TestNewaxis(TestCase):
         assert_almost_equal(res.ravel(), 250 * sk)


+_sctypes = {
+    "int": [np.int8, np.int16, np.int32, np.int64],
+    "uint": [np.uint8, np.uint16, np.uint32, np.uint64],
+    "float": [np.float32, np.float64],
+    "complex": [np.complex64, np.complex128]
+    # no complex256 in torch._numpy
+    + ([np.clongdouble] if hasattr(np, "clongdouble") else []),
+}
+
+
 class TestClip(TestCase):
     def _check_range(self, x, cmin, cmax):
         assert_(np.all(x >= cmin))
@@ -3506,7 +3519,7 @@ class TestClip(TestCase):
         if expected_max is None:
             expected_max = clip_max

-        for T in np.sctypes[type_group]:
+        for T in _sctypes[type_group]:
            if sys.byteorder == "little":
                byte_orders = ["=", ">"]
            else:
@@ -6410,7 +6423,7 @@ class TestConversion(TestCase):
             # gh-9972
             assert_equal(4, int_func(np.array("4")))
             assert_equal(5, int_func(np.bytes_(b"5")))
-            assert_equal(6, int_func(np.unicode_("6")))
+            assert_equal(6, int_func(np.str_("6")))

             # The delegation of int() to __trunc__ was deprecated in
             # Python 3.11.

View File

@@ -146,7 +146,7 @@ class TestNonarrayArgs(TestCase):
     def test_cumproduct(self):
         A = [[1, 2, 3], [4, 5, 6]]
-        assert_(np.all(np.cumproduct(A) == np.array([1, 2, 6, 24, 120, 720])))
+        assert_(np.all(np.cumprod(A) == np.array([1, 2, 6, 24, 120, 720])))

     def test_diagonal(self):
         a = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
@@ -701,7 +701,7 @@ class TestFloatExceptions(TestCase):
     @parametrize("typecode", np.typecodes["AllFloat"])
     def test_floating_exceptions(self, typecode):
         # Test basic arithmetic function errors
-        ftype = np.obj2sctype(typecode)
+        ftype = np.dtype(typecode).type
         if np.dtype(ftype).kind == "f":
             # Get some extreme values for the type
             fi = np.finfo(ftype)
@@ -924,14 +924,19 @@ class TestTypes(TestCase):
     @xpassIfTorchDynamo  # (reason="value-based casting?")
     def test_can_cast_values(self):
         # gh-5917
-        for dt in np.sctypes["int"] + np.sctypes["uint"]:
+        for dt in [np.int8, np.int16, np.int32, np.int64] + [
+            np.uint8,
+            np.uint16,
+            np.uint32,
+            np.uint64,
+        ]:
             ii = np.iinfo(dt)
             assert_(np.can_cast(ii.min, dt))
             assert_(np.can_cast(ii.max, dt))
             assert_(not np.can_cast(ii.min - 1, dt))
             assert_(not np.can_cast(ii.max + 1, dt))
-        for dt in np.sctypes["float"]:
+        for dt in [np.float16, np.float32, np.float64, np.longdouble]:
             fi = np.finfo(dt)
             assert_(np.can_cast(fi.min, dt))
             assert_(np.can_cast(fi.max, dt))
@@ -969,8 +974,8 @@ class TestFromiter(TestCase):
         expected = np.array(list(self.makegen()))
         a = np.fromiter(self.makegen(), int)
         a20 = np.fromiter(self.makegen(), int, 20)
-        assert_(np.alltrue(a == expected, axis=0))
-        assert_(np.alltrue(a20 == expected[:20], axis=0))
+        assert_(np.all(a == expected, axis=0))
+        assert_(np.all(a20 == expected[:20], axis=0))

     def load_data(self, n, eindex):
         # Utility method for the issue 2592 tests.
@@ -2159,7 +2164,6 @@ class TestCreationFuncs(TestCase):
     def setUp(self):
         super().setUp()
-        # dtypes = {np.dtype(tp) for tp in itertools.chain.from_iterable(np.sctypes.values())}
         dtypes = {np.dtype(tp) for tp in "efdFDBbhil?"}
         self.dtypes = dtypes
         self.orders = {

View File

@@ -238,7 +238,11 @@ class TestClassGetitemMisc(TestCase):
 class TestBitCount(TestCase):
     # derived in part from the cpython test "test_bit_count"

-    @parametrize("itype", np.sctypes["int"] + np.sctypes["uint"])
+    @parametrize(
+        "itype",
+        [np.int8, np.int16, np.int32, np.int64]
+        + [np.uint8, np.uint16, np.uint32, np.uint64],
+    )
     def test_small(self, itype):
         for a in range(max(np.iinfo(itype).min, 0), 128):
             msg = f"Smoke test for {itype}({a}).bit_count()"

View File

@@ -123,7 +123,7 @@ def _make_complex(real, imag):
    Like real + 1j * imag, but behaves as expected when imag contains non-finite
    values
    """
-    ret = np.zeros(np.broadcast(real, imag).shape, np.complex_)
+    ret = np.zeros(np.broadcast(real, imag).shape, np.complex128)
    ret.real = real
    ret.imag = imag
    return ret
@@ -264,8 +264,8 @@ class TestAny(TestCase):
     def test_nd(self):
         y1 = [[0, 0, 0], [0, 1, 0], [1, 1, 0]]
         assert_(np.any(y1))
-        assert_array_equal(np.sometrue(y1, axis=0), [1, 1, 0])
-        assert_array_equal(np.sometrue(y1, axis=1), [0, 1, 1])
+        assert_array_equal(np.any(y1, axis=0), [1, 1, 0])
+        assert_array_equal(np.any(y1, axis=1), [0, 1, 1])


 class TestAll(TestCase):
@@ -281,8 +281,8 @@ class TestAll(TestCase):
     def test_nd(self):
         y1 = [[0, 0, 1], [0, 1, 1], [1, 1, 1]]
         assert_(not np.all(y1))
-        assert_array_equal(np.alltrue(y1, axis=0), [0, 0, 1])
-        assert_array_equal(np.alltrue(y1, axis=1), [0, 0, 1])
+        assert_array_equal(np.all(y1, axis=0), [0, 0, 1])
+        assert_array_equal(np.all(y1, axis=1), [0, 0, 1])


 class TestCopy(TestCase):
@@ -492,7 +492,7 @@ class TestSelect(TestCase):
         assert_equal(select([True], [0], default=[0]).shape, (1,))

     def test_return_dtype(self):
-        assert_equal(select(self.conditions, self.choices, 1j).dtype, np.complex_)
+        assert_equal(select(self.conditions, self.choices, 1j).dtype, np.complex128)
         # But the conditions need to be stronger then the scalar default
         # if it is scalar.
         choices = [choice.astype(np.int8) for choice in self.choices]
@@ -2603,7 +2603,7 @@ class TestBincount(TestCase):
 parametrize_interp_sc = parametrize(
     "sc",
     [
-        subtest(lambda x: np.float_(x), name="real"),
+        subtest(lambda x: np.float64(x), name="real"),
         subtest(lambda x: _make_complex(x, 0), name="complex-real"),
         subtest(lambda x: _make_complex(0, x), name="complex-imag"),
         subtest(lambda x: _make_complex(x, np.multiply(x, -2)), name="complex-both"),
@@ -2859,9 +2859,9 @@ class TestPercentile(TestCase):
     @parametrize("dtype", np.typecodes["Float"])
     def test_linear_nan_1D(self, dtype):
         # METHOD 1 of H&F
-        arr = np.asarray([15.0, np.NAN, 35.0, 40.0, 50.0], dtype=dtype)
+        arr = np.asarray([15.0, np.nan, 35.0, 40.0, 50.0], dtype=dtype)
         res = np.percentile(arr, 40.0, method="linear")
-        np.testing.assert_equal(res, np.NAN)
+        np.testing.assert_equal(res, np.nan)
         np.testing.assert_equal(res.dtype, arr.dtype)

 H_F_TYPE_CODES = [

View File

@@ -204,7 +204,7 @@ class TestIscomplex(TestCase):
     def test_fail(self):
         z = np.array([-1, 0, 1])
         res = iscomplex(z)
-        assert_(not np.sometrue(res, axis=0))
+        assert_(not np.any(res, axis=0))

     def test_pass(self):
         z = np.array([-1j, 1, 0])
@@ -389,19 +389,19 @@ class TestNanToNum(TestCase):
     def test_float(self):
         vals = nan_to_num(1.0)
         assert_all(vals == 1.0)
-        assert_equal(type(vals), np.float_)
+        assert_equal(type(vals), np.float64)

         vals = nan_to_num(1.1, nan=10, posinf=20, neginf=30)
         assert_all(vals == 1.1)
-        assert_equal(type(vals), np.float_)
+        assert_equal(type(vals), np.float64)

     @skip(reason="we return OD arrays not scalars")
     def test_complex_good(self):
         vals = nan_to_num(1 + 1j)
         assert_all(vals == 1 + 1j)
-        assert isinstance(vals, np.complex_)
+        assert isinstance(vals, np.complex128)

         vals = nan_to_num(1 + 1j, nan=10, posinf=20, neginf=30)
         assert_all(vals == 1 + 1j)
-        assert_equal(type(vals), np.complex_)
+        assert_equal(type(vals), np.complex128)

     @skip(reason="we return OD arrays not scalars")
     def test_complex_bad(self):
@@ -410,7 +410,7 @@ class TestNanToNum(TestCase):
         vals = nan_to_num(v)
         # !! This is actually (unexpectedly) zero
         assert_all(np.isfinite(vals))
-        assert_equal(type(vals), np.complex_)
+        assert_equal(type(vals), np.complex128)

     @skip(reason="we return OD arrays not scalars")
     def test_complex_bad2(self):
@@ -418,7 +418,7 @@ class TestNanToNum(TestCase):
         v += np.array(-1 + 1.0j) / 0.0
         vals = nan_to_num(v)
         assert_all(np.isfinite(vals))
-        assert_equal(type(vals), np.complex_)
+        assert_equal(type(vals), np.complex128)
         # Fixme
         # assert_all(vals.imag > 1e10) and assert_all(np.isfinite(vals))
         # !! This is actually (unexpectedly) positive

View File

@@ -851,7 +851,7 @@ class TestCond(CondCases, TestCase):
        A[0, 1] = np.nan
        for p in ps:
            c = linalg.cond(A, p)
-           assert_(isinstance(c, np.float_))
+           assert_(isinstance(c, np.float64))
            assert_(np.isnan(c))

        A = np.ones((3, 2, 2))

View File

@@ -21,7 +21,10 @@ from ._util import AxisError, UFuncTypeError
 from math import pi, e  # usort: skip

+all = all
 alltrue = all
+
+any = any
 sometrue = any

 inf = float("inf")