mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129758 Approved by: https://github.com/ezyang
221 lines
6.6 KiB
Python
221 lines
6.6 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
"""Test examples for NEP 50."""
|
|
|
|
import itertools
|
|
from unittest import skipIf as skipif, SkipTest
|
|
|
|
|
|
# Probe for a usable NumPy.  ``HAVE_NUMPY`` gates both test classes in this
# module via ``skipif``.
try:
    import numpy as _np
except ImportError:
    HAVE_NUMPY = False
else:
    # NOTE(review): the major and minor components are tested independently,
    # so any 2.x release with minor < 24 (e.g. "2.0.1") yields False and the
    # tests are skipped.  Confirm whether that exclusion is intentional before
    # switching to a tuple comparison such as ``(major, minor) >= (1, 24)``.
    v = _np.__version__.split(".")
    HAVE_NUMPY = int(v[0]) >= 1 and int(v[1]) >= 24
|
|
|
|
import torch._numpy as tnp
|
|
from torch._numpy import ( # noqa: F401
|
|
array,
|
|
bool_,
|
|
complex128,
|
|
complex64,
|
|
float32,
|
|
float64,
|
|
inf,
|
|
int16,
|
|
int32,
|
|
int64,
|
|
uint8,
|
|
)
|
|
from torch._numpy.testing import assert_allclose
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
# Placeholder alias.  ``uint16`` is only used in the "old result" column of an
# ``examples`` row whose new result is ``Exception``; the test never compares
# the old value for such rows, so the concrete dtype does not matter.
uint16 = uint8  # can be anything here, see below
|
|
|
|
|
|
# from numpy import array, uint8, uint16, int64, float32, float64, inf
|
|
# from numpy.testing import assert_allclose
|
|
# import numpy as np
|
|
# np._set_promotion_state('weak')
|
|
|
|
from pytest import raises as assert_raises
|
|
|
|
|
|
unchanged = None
|
|
|
|
# expression old result new_result
|
|
examples = {
|
|
"uint8(1) + 2": (int64(3), uint8(3)),
|
|
"array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)),
|
|
"array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)),
|
|
"array([1.], float32) + float64(1.)": (
|
|
array([2.0], float32),
|
|
array([2.0], float64),
|
|
),
|
|
"array([1.], float32) + array(1., float64)": (
|
|
array([2.0], float32),
|
|
array([2.0], float64),
|
|
),
|
|
"array([1], uint8) + 1": (array([2], uint8), unchanged),
|
|
"array([1], uint8) + 200": (array([201], uint8), unchanged),
|
|
"array([100], uint8) + 200": (array([44], uint8), unchanged),
|
|
"array([1], uint8) + 300": (array([301], uint16), Exception),
|
|
"uint8(1) + 300": (int64(301), Exception),
|
|
"uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning
|
|
"float32(1) + 3e100": (float64(3e100), float32(inf)), # and RuntimeWarning [T7]
|
|
"array([1.0], float32) + 1e-14 == 1.0": (array([True]), unchanged),
|
|
"array([0.1], float32) == float64(0.1)": (array([True]), array([False])),
|
|
"array(1.0, float32) + 1e-14 == 1.0": (array(False), array(True)),
|
|
"array([1.], float32) + 3": (array([4.0], float32), unchanged),
|
|
"array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)),
|
|
"3j + array(3, complex64)": (array(3 + 3j, complex128), array(3 + 3j, complex64)),
|
|
"float32(1) + 1j": (array(1 + 1j, complex128), array(1 + 1j, complex64)),
|
|
"int32(1) + 5j": (array(1 + 5j, complex128), unchanged),
|
|
# additional examples from the NEP text
|
|
"int16(2) + 2": (int64(4), int16(4)),
|
|
"int16(4) + 4j": (complex128(4 + 4j), unchanged),
|
|
"float32(5) + 5j": (complex128(5 + 5j), complex64(5 + 5j)),
|
|
"bool_(True) + 1": (int64(2), unchanged),
|
|
"True + uint8(2)": (uint8(3), unchanged),
|
|
}
|
|
|
|
|
|
@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestNEP50Table(TestCase):
    """Evaluate every expression of the ``examples`` table against torch._numpy."""

    @parametrize("example", examples)
    def test_nep50_exceptions(self, example):
        """Check one ``examples`` row.

        ``example`` is a key of ``examples``; it is evaluated with ``eval`` in
        this module's namespace.  Rows whose new result is ``Exception`` must
        raise ``OverflowError``; all other rows must match the expected value
        and dtype (falling back to the old result when marked ``unchanged``).
        """
        old, new = examples[example]

        # Identity test: ``new`` may be an array, and ``==`` on arrays is an
        # elementwise operation rather than a sentinel check.
        if new is Exception:
            with assert_raises(OverflowError):
                # The expression is a literal key from the table above,
                # never external input.
                eval(example)
            return

        result = eval(example)

        if new is unchanged:
            new = old

        assert_allclose(result, new, atol=1e-16)
        assert result.dtype == new.dtype
|
|
|
|
|
|
# ### Directly compare to numpy ###
|
|
|
|
# Python scalars used as the first ("weak") operand in the comparison tests.
weaks = (True, 1, 2.0, 3j)

# torch._numpy operands with an explicit dtype, used as the second operand.
non_weaks = (
    tnp.asarray(True),
    *(
        getattr(tnp, type_name)(1)
        for type_name in (
            "uint8",
            "int8",
            "int32",
            "int64",
            "float32",
            "float64",
            "complex64",
            "complex128",
        )
    ),
)

# Values for the ``dtype=`` keyword; ``None`` means "do not pass dtype".
# The NumPy dtypes can only be listed when NumPy is available.
dtypes = (None,)
if HAVE_NUMPY:
    dtypes += tuple(
        getattr(_np, type_name)
        for type_name in (
            "bool_",
            "uint8",
            "int8",
            "int32",
            "int64",
            "float32",
            "float64",
            "complex64",
            "complex128",
        )
    )

# ufunc name: [array.dtype]
# Combinations listed here are skipped by ``test_compare_ufuncs``.
corners = {
    **{
        ufunc_name: ["bool_", "uint8", "int8", "int16", "int32", "int64"]
        for ufunc_name in (
            "true_divide",
            "divide",
            "arctan2",
            "copysign",
            "heaviside",
            "ldexp",
        )
    },
    "power": ["uint8"],
    "nextafter": ["float32"],
}
|
|
|
|
|
|
@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestCompareToNumpy(TestCase):
    """Compare torch._numpy results with NumPy running in "weak" promotion state."""

    @parametrize("scalar, array, dtype", itertools.product(weaks, non_weaks, dtypes))
    def test_direct_compare(self, scalar, array, dtype):
        """``tnp.add(scalar, array)`` must match ``numpy.add`` in value and dtype.

        ``dtype`` is either ``None`` (no ``dtype=`` keyword is passed) or a
        NumPy dtype, which is mapped to the torch._numpy dtype of the same name.
        Combinations that NumPy itself rejects are not compared.
        """
        # compare to NumPy w/ NEP 50.
        # Read the current state before entering ``try`` so that ``finally``
        # always has a value to restore.
        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            # Must be bound for ``dtype is None`` as well: an unbound name here
            # would raise NameError inside the ``try`` below and be swallowed,
            # silently skipping every dtype-less comparison.
            kwargs_numpy = {} if dtype is None else {"dtype": dtype}
            try:
                result_numpy = _np.add(scalar, array.tensor.numpy(), **kwargs_numpy)
            except Exception:
                # NumPy refuses this scalar/array/dtype combination, so there
                # is no reference result to compare against.
                return

            kwargs = {} if dtype is None else {"dtype": getattr(tnp, dtype.__name__)}
            result = tnp.add(scalar, array, **kwargs).tensor.numpy()
            assert result.dtype == result_numpy.dtype
            assert result == result_numpy
        finally:
            _np._set_promotion_state(state)

    @parametrize("name", tnp._ufuncs._binary)
    @parametrize("scalar, array", itertools.product(weaks, non_weaks))
    def test_compare_ufuncs(self, name, scalar, array):
        """Result dtype of binary ufunc ``name`` must match NumPy's.

        Combinations listed in ``corners`` are skipped.  The dtype is compared
        only when both torch._numpy and NumPy produce a result.
        """
        if name in corners and (
            array.dtype.name in corners[name]
            or tnp.asarray(scalar).dtype.name in corners[name]
        ):
            raise SkipTest(f"{name}(..., dtype=array.dtype)")

        # Not compared by this test; bail out before touching NumPy's global
        # promotion state.
        if name in ["matmul", "modf", "divmod", "ldexp"]:
            return

        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            ufunc = getattr(tnp, name)
            ufunc_numpy = getattr(_np, name)

            try:
                result = ufunc(scalar, array)
            except RuntimeError:
                # RuntimeError: "bitwise_xor_cpu" not implemented for 'ComplexDouble' etc
                result = None

            try:
                result_numpy = ufunc_numpy(scalar, array.tensor.numpy())
            except TypeError:
                # TypeError: ufunc 'hypot' not supported for the input types
                result_numpy = None

            if result is not None and result_numpy is not None:
                assert result.tensor.numpy().dtype == result_numpy.dtype
        finally:
            _np._set_promotion_state(state)
|
|
|
|
|
|
# Script entry point: hand control to the shared test runner.
if __name__ == "__main__":
    run_tests()
|