TST: pytorchify test/torch_np/test_dtype.py (#109967)

This file was missing from https://github.com/pytorch/pytorch/pull/109593

NB: This PR only mechanically converts the test. Will add more tests to see what's going on with `dtype=np.float64` etc under dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109967
Approved by: https://github.com/lezcano
This commit is contained in:
Evgeni Burovski
2023-09-24 13:34:02 +00:00
committed by PyTorch MergeBot
parent 95e2eec9bf
commit befe60afc2

View File

@ -1,11 +1,18 @@
# Owner(s): ["module: dynamo"] # Owner(s): ["module: dynamo"]
import numpy as np from unittest import expectedFailure as xfail
import pytest
import numpy
import torch._numpy as tnp import torch._numpy as tnp
from torch.testing._internal.common_utils import run_tests from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TestCase,
)
dtype_names = [ dtype_names = [
@ -15,45 +22,41 @@ dtype_names = [
*[f"float{w}" for w in [16, 32, 64]], *[f"float{w}" for w in [16, 32, 64]],
*[f"complex{w}" for w in [64, 128]], *[f"complex{w}" for w in [64, 128]],
] ]
np_dtype_params = [] np_dtype_params = []
np_dtype_params.append(pytest.param("bool", "bool", id="'bool'"))
np_dtype_params.append( np_dtype_params = [
pytest.param( subtest(("bool", "bool"), name="bool"),
"bool", subtest(
np.dtype("bool"), ("bool", numpy.dtype("bool")),
id="np.dtype('bool')", name="numpy.dtype('bool')",
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"), decorators=[xfail], # reason="XXX: np.dtype() objects not supported"),
) ),
) ]
for name in dtype_names: for name in dtype_names:
np_dtype_params.append(pytest.param(name, name, id=repr(name))) np_dtype_params.append(subtest((name, name), name=repr(name)))
np_dtype_params.append( np_dtype_params.append(
pytest.param( subtest((name, getattr(numpy, name)), name=f"numpy.{name}", decorators=[xfail])
name, ) # numpy namespaced dtypes not supported
getattr(np, name),
id=f"np.{name}",
marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"),
)
)
np_dtype_params.append( np_dtype_params.append(
pytest.param( subtest((name, numpy.dtype(name)), name=f"numpy.{name!r}", decorators=[xfail])
name,
np.dtype(name),
id=f"np.dtype({name!r})",
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
)
) )
@pytest.mark.parametrize("name, np_dtype", np_dtype_params) @instantiate_parametrized_tests
def test_convert_np_dtypes(name, np_dtype): class TestConvertDType(TestCase):
tnp_dtype = tnp.dtype(np_dtype) @parametrize("name, np_dtype", np_dtype_params)
if name == "bool_": def test_convert_np_dtypes(self, name, np_dtype):
assert tnp_dtype == tnp.bool_ tnp_dtype = tnp.dtype(np_dtype)
elif tnp_dtype.name == "bool_": if name == "bool_":
assert name.startswith("bool") assert tnp_dtype == tnp.bool_
else: elif tnp_dtype.name == "bool_":
assert tnp_dtype.name == name assert name.startswith("bool")
else:
assert tnp_dtype.name == name
if __name__ == "__main__": if __name__ == "__main__":