mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
TST: pytorchify test/torch_np/test_dtype.py (#109967)
This file was missing from https://github.com/pytorch/pytorch/pull/109593 NB: This PR only mechanically converts the test. Will add more tests to see what's going on with `dtype=np.float64` etc under dynamo. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109967 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
95e2eec9bf
commit
befe60afc2
@ -1,11 +1,18 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest import expectedFailure as xfail
|
||||
|
||||
import numpy
|
||||
|
||||
import torch._numpy as tnp
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
subtest,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
dtype_names = [
|
||||
@ -15,45 +22,41 @@ dtype_names = [
|
||||
*[f"float{w}" for w in [16, 32, 64]],
|
||||
*[f"complex{w}" for w in [64, 128]],
|
||||
]
|
||||
|
||||
np_dtype_params = []
|
||||
np_dtype_params.append(pytest.param("bool", "bool", id="'bool'"))
|
||||
np_dtype_params.append(
|
||||
pytest.param(
|
||||
"bool",
|
||||
np.dtype("bool"),
|
||||
id="np.dtype('bool')",
|
||||
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
|
||||
)
|
||||
)
|
||||
|
||||
np_dtype_params = [
|
||||
subtest(("bool", "bool"), name="bool"),
|
||||
subtest(
|
||||
("bool", numpy.dtype("bool")),
|
||||
name="numpy.dtype('bool')",
|
||||
decorators=[xfail], # reason="XXX: np.dtype() objects not supported"),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
for name in dtype_names:
|
||||
np_dtype_params.append(pytest.param(name, name, id=repr(name)))
|
||||
np_dtype_params.append(subtest((name, name), name=repr(name)))
|
||||
|
||||
np_dtype_params.append(
|
||||
pytest.param(
|
||||
name,
|
||||
getattr(np, name),
|
||||
id=f"np.{name}",
|
||||
marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"),
|
||||
)
|
||||
)
|
||||
subtest((name, getattr(numpy, name)), name=f"numpy.{name}", decorators=[xfail])
|
||||
) # numpy namespaced dtypes not supported
|
||||
np_dtype_params.append(
|
||||
pytest.param(
|
||||
name,
|
||||
np.dtype(name),
|
||||
id=f"np.dtype({name!r})",
|
||||
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
|
||||
)
|
||||
subtest((name, numpy.dtype(name)), name=f"numpy.{name!r}", decorators=[xfail])
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name, np_dtype", np_dtype_params)
|
||||
def test_convert_np_dtypes(name, np_dtype):
|
||||
tnp_dtype = tnp.dtype(np_dtype)
|
||||
if name == "bool_":
|
||||
assert tnp_dtype == tnp.bool_
|
||||
elif tnp_dtype.name == "bool_":
|
||||
assert name.startswith("bool")
|
||||
else:
|
||||
assert tnp_dtype.name == name
|
||||
@instantiate_parametrized_tests
|
||||
class TestConvertDType(TestCase):
|
||||
@parametrize("name, np_dtype", np_dtype_params)
|
||||
def test_convert_np_dtypes(self, name, np_dtype):
|
||||
tnp_dtype = tnp.dtype(np_dtype)
|
||||
if name == "bool_":
|
||||
assert tnp_dtype == tnp.bool_
|
||||
elif tnp_dtype.name == "bool_":
|
||||
assert name.startswith("bool")
|
||||
else:
|
||||
assert tnp_dtype.name == name
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user