Files
pytorch/test/torch_np/test_nep50_examples.py
Yuanyuan Chen 8de85896e0 Enable ruff rule E721 (#165162)
`E721` checks for object type comparisons using == and other comparison operators. This is useful because it is recommended to use `is` for type comparisons.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162
Approved by: https://github.com/Skylion007
2025-10-13 01:48:55 +00:00

221 lines
6.6 KiB
Python

# Owner(s): ["module: dynamo"]
"""Test examples for NEP 50."""
import itertools
from unittest import skipIf as skipif, SkipTest
# Detect whether a usable NumPy is installed; the NumPy-dependent test
# classes in this file are skipped when HAVE_NUMPY is False.
try:
    import numpy as _np
except ImportError:
    HAVE_NUMPY = False
else:
    v = _np.__version__.split(".")
    # NOTE(review): the minor-version test is applied regardless of the major
    # version, so a version string such as "2.0.0" evaluates to False here.
    # Confirm whether excluding NumPy 2.x is intended before changing this.
    HAVE_NUMPY = int(v[0]) >= 1 and int(v[1]) >= 24
import torch._numpy as tnp
from torch._numpy import ( # noqa: F401
array,
bool_,
complex128,
complex64,
float32,
float64,
inf,
int16,
int32,
int64,
uint8,
)
from torch._numpy.testing import assert_allclose
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
# ``uint16`` is referenced only by the "old result" column of the
# ``array([1], uint8) + 300`` example, whose "new result" is ``Exception``.
# The test never compares against the old result in that case, so aliasing the
# name to ``uint8`` merely lets the table entry be constructed.
uint16 = uint8 # can be anything here, see below
# Commented-out alternative setup, presumably kept so the same table can be
# evaluated against NumPy itself (TODO confirm it is still wanted):
# from numpy import array, uint8, uint16, int64, float32, float64, inf
# from numpy.testing import assert_allclose
# import numpy as np
# np._set_promotion_state('weak')
from pytest import raises as assert_raises
# Sentinel for the "new result" column meaning "same as the old result".
unchanged = None
# Table of NEP 50 examples, keyed by the expression string that the test
# evaluates with ``eval``. Each value is a pair ``(old result, new result)``.
# The new result is one of:
#   * a concrete value  -> compared against the evaluated expression,
#   * ``unchanged``     -> the old result is used as the expected value,
#   * ``Exception``     -> the expression is expected to raise OverflowError.
# The old result is therefore only checked for entries marked ``unchanged``.
# Insertion order matters: the dict is passed directly to ``parametrize``.
#
# expression: (old result, new result)
examples = {
    "uint8(1) + 2": (int64(3), uint8(3)),
    "array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)),
    "array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)),
    "array([1.], float32) + float64(1.)": (
        array([2.0], float32),
        array([2.0], float64),
    ),
    "array([1.], float32) + array(1., float64)": (
        array([2.0], float32),
        array([2.0], float64),
    ),
    "array([1], uint8) + 1": (array([2], uint8), unchanged),
    "array([1], uint8) + 200": (array([201], uint8), unchanged),
    "array([100], uint8) + 200": (array([44], uint8), unchanged),
    "array([1], uint8) + 300": (array([301], uint16), Exception),
    "uint8(1) + 300": (int64(301), Exception),
    # NOTE(review): 100 + 200 is 300, not 301. Harmless for the test because
    # the old result is ignored whenever a concrete new result is given.
    "uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning
    "float32(1) + 3e100": (float64(3e100), float32(inf)), # and RuntimeWarning [T7]
    "array([1.0], float32) + 1e-14 == 1.0": (array([True]), unchanged),
    "array([0.1], float32) == float64(0.1)": (array([True]), array([False])),
    "array(1.0, float32) + 1e-14 == 1.0": (array(False), array(True)),
    "array([1.], float32) + 3": (array([4.0], float32), unchanged),
    "array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)),
    "3j + array(3, complex64)": (array(3 + 3j, complex128), array(3 + 3j, complex64)),
    "float32(1) + 1j": (array(1 + 1j, complex128), array(1 + 1j, complex64)),
    "int32(1) + 5j": (array(1 + 5j, complex128), unchanged),
    # additional examples from the NEP text
    "int16(2) + 2": (int64(4), int16(4)),
    "int16(4) + 4j": (complex128(4 + 4j), unchanged),
    "float32(5) + 5j": (complex128(5 + 5j), complex64(5 + 5j)),
    "bool_(True) + 1": (int64(2), unchanged),
    "True + uint8(2)": (uint8(3), unchanged),
}
@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestNEP50Table(TestCase):
    """Evaluate every expression in ``examples`` and check its NEP 50 result."""

    @parametrize("example", examples)
    def test_nep50_exceptions(self, example):
        # ``example`` is a key of the module-level ``examples`` table; the
        # expression strings are trusted constants defined in this file, which
        # is why ``eval`` is acceptable here.
        legacy, expected = examples[example]

        # Entries marked with ``Exception`` must overflow instead of producing
        # a value.
        if expected is Exception:
            with assert_raises(OverflowError):
                eval(example)
            return

        # ``unchanged`` means the pre-NEP 50 result is still the right answer.
        if expected is unchanged:
            expected = legacy

        actual = eval(example)
        assert_allclose(actual, expected, atol=1e-16)
        assert actual.dtype == expected.dtype
# ### Directly compare to numpy ###
weaks = (True, 1, 2.0, 3j)
non_weaks = (
tnp.asarray(True),
tnp.uint8(1),
tnp.int8(1),
tnp.int32(1),
tnp.int64(1),
tnp.float32(1),
tnp.float64(1),
tnp.complex64(1),
tnp.complex128(1),
)
# Values for the ``dtype=`` argument of the direct comparison test. ``None``
# (no explicit dtype) is always present; the NumPy scalar types are appended
# only when NumPy is available, since they live on the ``_np`` module.
dtypes = (None,)
if HAVE_NUMPY:
    dtypes += (
        _np.bool_,
        _np.uint8,
        _np.int8,
        _np.int32,
        _np.int64,
        _np.float32,
        _np.float64,
        _np.complex64,
        _np.complex128,
    )
# ufunc name: [array.dtype]
# Combinations listed here are skipped by the ufunc comparison test: a case is
# skipped when the dtype name of either operand appears in the list for the
# ufunc under test.
corners = {
    # These ufuncs share the same list of boolean/integer dtype names; the
    # list literal is re-evaluated per key, so every entry owns its own list.
    ufunc_name: ["bool_", "uint8", "int8", "int16", "int32", "int64"]
    for ufunc_name in (
        "true_divide",
        "divide",
        "arctan2",
        "copysign",
        "heaviside",
        "ldexp",
    )
}
corners["power"] = ["uint8"]
corners["nextafter"] = ["float32"]
@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestCompareToNumpy(TestCase):
    """Compare ``torch._numpy`` results with NumPy running in "weak" promotion.

    Both tests switch NumPy's promotion state to ``"weak"`` for the duration
    of the comparison and restore the previous state afterwards.
    """

    @parametrize("scalar, array, dtype", itertools.product(weaks, non_weaks, dtypes))
    def test_direct_compare(self, scalar, array, dtype):
        """Check that ``tnp.add`` matches ``_np.add`` in both dtype and value.

        ``scalar`` is a Python scalar, ``array`` a ``torch._numpy`` array and
        ``dtype`` either ``None`` or a NumPy scalar type passed as ``dtype=``.
        Combinations that NumPy itself rejects are not compared.
        """
        # compare to NumPy w/ NEP 50.
        # Read the current state before entering ``try`` so that ``finally``
        # can never hit an unbound ``state`` and mask the original error.
        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            # Always bind ``kwargs``. Leaving it unbound for ``dtype is None``
            # raised UnboundLocalError inside the ``try`` below, which the
            # broad ``except`` swallowed, so those cases were never compared.
            kwargs = {} if dtype is None else {"dtype": dtype}
            try:
                result_numpy = _np.add(scalar, array.tensor.numpy(), **kwargs)
            except Exception:
                # Deliberate best effort: if NumPy cannot compute this
                # combination there is no reference result to compare against.
                return

            # Translate the NumPy dtype to the same-named torch._numpy dtype.
            kwargs = {} if dtype is None else {"dtype": getattr(tnp, dtype.__name__)}
            result = tnp.add(scalar, array, **kwargs).tensor.numpy()
            assert result.dtype == result_numpy.dtype
            assert result == result_numpy
        finally:
            _np._set_promotion_state(state)

    @parametrize("name", tnp._ufuncs._binary)
    @parametrize("scalar, array", itertools.product(weaks, non_weaks))
    def test_compare_ufuncs(self, name, scalar, array):
        """Check that binary ufunc ``name`` yields the same result dtype as NumPy.

        Only the dtype is compared, and only when both implementations
        produce a result. Combinations listed in ``corners`` are skipped.
        """
        if name in corners and (
            array.dtype.name in corners[name]
            or tnp.asarray(scalar).dtype.name in corners[name]
        ):
            raise SkipTest(f"{name}(..., dtype=array.dtype)")

        # These ufuncs are not compared at all; bail out before touching
        # NumPy's global promotion state.
        if name in ["matmul", "modf", "divmod", "ldexp"]:
            return

        # Read the current state before entering ``try`` so that ``finally``
        # can never hit an unbound ``state`` and mask the original error.
        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            ufunc = getattr(tnp, name)
            ufunc_numpy = getattr(_np, name)

            try:
                result = ufunc(scalar, array)
            except RuntimeError:
                # RuntimeError: "bitwise_xor_cpu" not implemented for 'ComplexDouble' etc
                result = None

            try:
                result_numpy = ufunc_numpy(scalar, array.tensor.numpy())
            except TypeError:
                # TypeError: ufunc 'hypot' not supported for the input types
                result_numpy = None

            if result is not None and result_numpy is not None:
                assert result.tensor.numpy().dtype == result_numpy.dtype
        finally:
            _np._set_promotion_state(state)
# Script entry point: when the file is executed directly, hand control to
# ``run_tests`` imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()