TST: pytorchify test/torch_np/test_dtype.py (#109967)

This file was missing from https://github.com/pytorch/pytorch/pull/109593

NB: This PR only mechanically converts the test. Will add more tests to see what's going on with `dtype=np.float64` etc under dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109967
Approved by: https://github.com/lezcano
This commit is contained in:
Evgeni Burovski
2023-09-24 13:34:02 +00:00
committed by PyTorch MergeBot
parent 95e2eec9bf
commit befe60afc2

View File

@ -1,11 +1,18 @@
# Owner(s): ["module: dynamo"]
import numpy as np
import pytest
from unittest import expectedFailure as xfail
import numpy
import torch._numpy as tnp
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TestCase,
)
dtype_names = [
@ -15,45 +22,41 @@ dtype_names = [
*[f"float{w}" for w in [16, 32, 64]],
*[f"complex{w}" for w in [64, 128]],
]
np_dtype_params = []
np_dtype_params.append(pytest.param("bool", "bool", id="'bool'"))
np_dtype_params.append(
pytest.param(
"bool",
np.dtype("bool"),
id="np.dtype('bool')",
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
)
)
np_dtype_params = [
subtest(("bool", "bool"), name="bool"),
subtest(
("bool", numpy.dtype("bool")),
name="numpy.dtype('bool')",
decorators=[xfail], # reason="XXX: np.dtype() objects not supported"),
),
]
for name in dtype_names:
np_dtype_params.append(pytest.param(name, name, id=repr(name)))
np_dtype_params.append(subtest((name, name), name=repr(name)))
np_dtype_params.append(
pytest.param(
name,
getattr(np, name),
id=f"np.{name}",
marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"),
)
)
subtest((name, getattr(numpy, name)), name=f"numpy.{name}", decorators=[xfail])
) # numpy namespaced dtypes not supported
np_dtype_params.append(
pytest.param(
name,
np.dtype(name),
id=f"np.dtype({name!r})",
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
)
subtest((name, numpy.dtype(name)), name=f"numpy.{name!r}", decorators=[xfail])
)
@pytest.mark.parametrize("name, np_dtype", np_dtype_params)
def test_convert_np_dtypes(name, np_dtype):
tnp_dtype = tnp.dtype(np_dtype)
if name == "bool_":
assert tnp_dtype == tnp.bool_
elif tnp_dtype.name == "bool_":
assert name.startswith("bool")
else:
assert tnp_dtype.name == name
@instantiate_parametrized_tests
class TestConvertDType(TestCase):
@parametrize("name, np_dtype", np_dtype_params)
def test_convert_np_dtypes(self, name, np_dtype):
tnp_dtype = tnp.dtype(np_dtype)
if name == "bool_":
assert tnp_dtype == tnp.bool_
elif tnp_dtype.name == "bool_":
assert name.startswith("bool")
else:
assert tnp_dtype.name == name
if __name__ == "__main__":