mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
TST: pytorchify test/torch_np/test_dtype.py (#109967)
This file was missing from https://github.com/pytorch/pytorch/pull/109593 NB: This PR only mechanically converts the test. Will add more tests to see what's going on with `dtype=np.float64` etc under dynamo. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109967 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
95e2eec9bf
commit
befe60afc2
@ -1,11 +1,18 @@
|
|||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
import numpy as np
|
from unittest import expectedFailure as xfail
|
||||||
import pytest
|
|
||||||
|
import numpy
|
||||||
|
|
||||||
import torch._numpy as tnp
|
import torch._numpy as tnp
|
||||||
|
|
||||||
from torch.testing._internal.common_utils import run_tests
|
from torch.testing._internal.common_utils import (
|
||||||
|
instantiate_parametrized_tests,
|
||||||
|
parametrize,
|
||||||
|
run_tests,
|
||||||
|
subtest,
|
||||||
|
TestCase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
dtype_names = [
|
dtype_names = [
|
||||||
@ -15,45 +22,41 @@ dtype_names = [
|
|||||||
*[f"float{w}" for w in [16, 32, 64]],
|
*[f"float{w}" for w in [16, 32, 64]],
|
||||||
*[f"complex{w}" for w in [64, 128]],
|
*[f"complex{w}" for w in [64, 128]],
|
||||||
]
|
]
|
||||||
|
|
||||||
np_dtype_params = []
|
np_dtype_params = []
|
||||||
np_dtype_params.append(pytest.param("bool", "bool", id="'bool'"))
|
|
||||||
np_dtype_params.append(
|
np_dtype_params = [
|
||||||
pytest.param(
|
subtest(("bool", "bool"), name="bool"),
|
||||||
"bool",
|
subtest(
|
||||||
np.dtype("bool"),
|
("bool", numpy.dtype("bool")),
|
||||||
id="np.dtype('bool')",
|
name="numpy.dtype('bool')",
|
||||||
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
|
decorators=[xfail], # reason="XXX: np.dtype() objects not supported"),
|
||||||
)
|
),
|
||||||
)
|
]
|
||||||
|
|
||||||
|
|
||||||
for name in dtype_names:
|
for name in dtype_names:
|
||||||
np_dtype_params.append(pytest.param(name, name, id=repr(name)))
|
np_dtype_params.append(subtest((name, name), name=repr(name)))
|
||||||
|
|
||||||
np_dtype_params.append(
|
np_dtype_params.append(
|
||||||
pytest.param(
|
subtest((name, getattr(numpy, name)), name=f"numpy.{name}", decorators=[xfail])
|
||||||
name,
|
) # numpy namespaced dtypes not supported
|
||||||
getattr(np, name),
|
|
||||||
id=f"np.{name}",
|
|
||||||
marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
np_dtype_params.append(
|
np_dtype_params.append(
|
||||||
pytest.param(
|
subtest((name, numpy.dtype(name)), name=f"numpy.{name!r}", decorators=[xfail])
|
||||||
name,
|
|
||||||
np.dtype(name),
|
|
||||||
id=f"np.dtype({name!r})",
|
|
||||||
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name, np_dtype", np_dtype_params)
|
@instantiate_parametrized_tests
|
||||||
def test_convert_np_dtypes(name, np_dtype):
|
class TestConvertDType(TestCase):
|
||||||
tnp_dtype = tnp.dtype(np_dtype)
|
@parametrize("name, np_dtype", np_dtype_params)
|
||||||
if name == "bool_":
|
def test_convert_np_dtypes(self, name, np_dtype):
|
||||||
assert tnp_dtype == tnp.bool_
|
tnp_dtype = tnp.dtype(np_dtype)
|
||||||
elif tnp_dtype.name == "bool_":
|
if name == "bool_":
|
||||||
assert name.startswith("bool")
|
assert tnp_dtype == tnp.bool_
|
||||||
else:
|
elif tnp_dtype.name == "bool_":
|
||||||
assert tnp_dtype.name == name
|
assert name.startswith("bool")
|
||||||
|
else:
|
||||||
|
assert tnp_dtype.name == name
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Reference in New Issue
Block a user