Compare commits

...

47 Commits

Author SHA1 Message Date
e8a718a7e4 Update
[ghstack-poisoned]
2024-06-04 17:56:39 +00:00
591bc43d0d Update
[ghstack-poisoned]
2024-06-04 13:28:10 +00:00
3fc6dab449 Update (base update)
[ghstack-poisoned]
2024-06-04 13:28:10 +00:00
6bd96079f3 Update
[ghstack-poisoned]
2024-05-31 13:12:02 +00:00
20a9697b09 Update (base update)
[ghstack-poisoned]
2024-05-31 13:12:02 +00:00
452d20c56a Update
[ghstack-poisoned]
2024-05-28 17:56:42 +00:00
8147c5ae73 Update (base update)
[ghstack-poisoned]
2024-05-28 17:56:42 +00:00
f92d37e6de Update
[ghstack-poisoned]
2024-05-28 15:42:30 +00:00
42198e7aaa Update (base update)
[ghstack-poisoned]
2024-05-28 15:42:30 +00:00
e6b868fbf9 Update
[ghstack-poisoned]
2024-05-14 22:35:40 +00:00
38b220eb18 Update (base update)
[ghstack-poisoned]
2024-05-14 22:35:40 +00:00
5d586eebd3 Update
[ghstack-poisoned]
2024-05-13 18:39:33 +00:00
6d347bbfe9 Update (base update)
[ghstack-poisoned]
2024-05-13 18:39:33 +00:00
ca64f38a59 Update
[ghstack-poisoned]
2024-05-08 15:28:42 +00:00
0edec418e0 Update (base update)
[ghstack-poisoned]
2024-05-08 15:28:42 +00:00
bd69e1cd4a Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-07 22:32:31 +00:00
90ae1c0491 Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-07 22:32:31 +00:00
133c608201 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-07 22:23:22 +00:00
68eef4cbb6 Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-07 22:23:22 +00:00
0f14d35ddc Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-03 14:10:00 +00:00
f7d6eb01ce Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-03 14:10:00 +00:00
f109219fd7 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-02 23:35:33 +00:00
0dd27104b8 Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-02 23:35:33 +00:00
dc9580b0c9 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-02 19:35:25 +00:00
09dbeb7ecd Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-02 19:35:25 +00:00
f7eb123b20 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-02 13:22:53 +00:00
5844251957 Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-02 13:22:53 +00:00
a25f42aeb2 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-01 21:13:27 +00:00
6475c24efa Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-01 21:13:27 +00:00
48f51d3afa Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-01 19:27:55 +00:00
ec773a663c Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-01 19:27:55 +00:00
0b0a88e4ea Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-01 18:27:57 +00:00
1d47fb54cc Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-05-01 18:27:57 +00:00
cda63f9980 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-24 15:43:06 +00:00
af30f8e4a3 Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-24 15:43:06 +00:00
b516b013d9 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-24 15:37:32 +00:00
55f2145f03 Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-24 15:37:32 +00:00
58ac759f97 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-24 15:23:26 +00:00
7a9b644baa Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-24 15:23:26 +00:00
7a18fd2249 Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-19 16:10:56 +00:00
e5619f7012 Update base for Update on "[dynamo] Support ndarray.dtype attribute access"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-19 16:10:56 +00:00
0d04859ed7 [dynamo] Support ndarray.dtype attribute access
[ghstack-poisoned]
2024-04-19 15:33:03 +00:00
e7ca7712ed Update on "[dynamo] Handle np.iinfo/finfo/dtype as input"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-19 15:20:21 +00:00
9f60bd5753 Update base for Update on "[dynamo] Handle np.iinfo/finfo/dtype as input"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-19 15:20:21 +00:00
55644249a8 Update on "[dynamo] Handle np.iinfo/finfo/dtype as input"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
2024-04-19 14:24:39 +00:00
f7ede74d24 [dynamo] Handle np.iinfo/finfo/dtype as input
[ghstack-poisoned]
2024-04-19 14:21:03 +00:00
cfeeb15a75 [dynamo] Support numpy.dtype
[ghstack-poisoned]
2024-04-19 14:20:56 +00:00
15 changed files with 582 additions and 40 deletions

View File

@ -1619,6 +1619,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
dt = np.dtype("float")
return np.full_like(x, 2.4, dtype=dt)
@make_test
def test_numpy_dtype_attr(x):
return np.ones_like(x).dtype == x.dtype
@make_test
def test_numpy_linalg(x):
return np.linalg.norm(x.numpy(), axis=0)

View File

@ -121,9 +121,7 @@ class TestBinaryUfuncs(TestCase):
def _helper_reference_numerics(
expected, actual, msg, exact_dtype, equal_nan=True
):
if not torch.can_cast(
numpy_to_torch_dtype_dict[expected.dtype.type], dtype
):
if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype], dtype):
exact_dtype = False
if dtype is torch.bfloat16 and expected.dtype == np.float32:

View File

@ -476,13 +476,18 @@ class TestNumPyInterop(TestCase):
self.assertTrue(r2.requires_grad)
@onlyCPU
def test_parse_numpy_int(self, device):
@skipIfTorchDynamo()
def test_parse_numpy_int_overflow(self, device):
# assertRaises uses a try-except which dynamo has issues with
# Only concrete class can be given where "Type[number[_64Bit]]" is expected
self.assertRaisesRegex(
RuntimeError,
"(Overflow|an integer is required)",
lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)),
) # type: ignore[call-overload]
@onlyCPU
def test_parse_numpy_int(self, device):
# https://github.com/pytorch/pytorch/issues/29252
for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
scalar = 3

View File

@ -184,7 +184,7 @@ class TestUnaryUfuncs(TestCase):
expected, actual, msg, exact_dtype, equal_nan=True
):
if not torch.can_cast(
numpy_to_torch_dtype_dict[expected.dtype.type], dtype
numpy_to_torch_dtype_dict[expected.dtype], dtype
):
exact_dtype = False

View File

@ -8,14 +8,17 @@ import warnings
# from numpy.core.getlimits import _discovered_machar, _float_ma
from unittest import skipIf
from unittest import expectedFailure as xfail, skipIf
import numpy
from pytest import raises as assert_raises
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
xpassIfTorchDynamo,
@ -109,6 +112,7 @@ class TestFinfo(TestCase):
getattr(finfo(dt), attr)
@instantiate_parametrized_tests
class TestIinfo(TestCase):
def test_basic(self):
dts = list(
@ -129,11 +133,19 @@ class TestIinfo(TestCase):
with assert_raises((TypeError, ValueError)):
iinfo("f4")
def test_unsigned_max(self):
types = np.sctypes["uint"]
for T in types:
max_calculated = T(0) - T(1)
assert_equal(iinfo(T).max, max_calculated)
@parametrize(
"T",
[
np.uint8,
# xfail: unsupported add (uint[16,32,64])
subtest(np.uint16, decorators=[xfail]),
subtest(np.uint32, decorators=[xfail]),
subtest(np.uint64, decorators=[xfail]),
],
)
def test_unsigned_max(self, T):
max_calculated = T(0) - T(1)
assert_equal(iinfo(T).max, max_calculated)
class TestRepr(TestCase):

View File

@ -1833,7 +1833,7 @@ class TestMethods(TestCase):
a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.unicode_)
assert_equal(a.argsort(kind="m"), r)
@xpassIfTorchDynamo # (reason="TODO: searchsorted with nans differs in pytorch")
@xfail # (reason="TODO: searchsorted with nans differs in pytorch")
@parametrize(
"a",
[
@ -1905,7 +1905,7 @@ class TestMethods(TestCase):
b = a.searchsorted([0, 1, 2], "right")
assert_equal(b, [0, 2, 2])
@xpassIfTorchDynamo # (
@xfail # (
# reason="RuntimeError: self.storage_offset() must be divisible by 8"
# )
def test_searchsorted_unaligned_array(self):
@ -1984,7 +1984,7 @@ class TestMethods(TestCase):
# assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1, 0, 1, 2, 3])
# assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4, 0, -1, 2, 3])
@xpassIfTorchDynamo # (reason="self.storage_offset() must be divisible by 8")
@xfail # (reason="self.storage_offset() must be divisible by 8")
def test_searchsorted_with_sorter(self):
a = np.random.rand(300)
s = a.argsort()
@ -3713,7 +3713,14 @@ class TestTake(TestCase):
y = np.take(x, [1, 2, 3], out=x[2:5], mode="wrap")
assert_equal(y, np.array([1, 2, 3]))
@parametrize("shape", [(1, 2), (1,), ()])
@parametrize(
"shape",
[
subtest((1, 2)),
subtest((1,)),
subtest((), decorators=[skip("Sensitive to np version")]),
],
)
def test_ret_is_out(self, shape):
# 0d arrays should not be an exception to this rule
x = np.arange(5)

View File

@ -732,13 +732,16 @@ class TestAbs(TestCase):
@instantiate_parametrized_tests
class TestBitShifts(TestCase):
@parametrize("type_code", np.typecodes["Integer"] + "B")
@parametrize("type_code", np.typecodes["AllInteger"])
@parametrize("op", [operator.rshift, operator.lshift])
def test_shift_all_bits(self, type_code, op):
"""Shifts where the shift amount is the width of the type or wider"""
# gh-2449
dt = np.dtype(type_code)
nbits = dt.itemsize * 8
if dt in (np.dtype(np.uint64), np.dtype(np.uint32), np.dtype(np.uint16)):
raise SkipTest("NYI: bitshift uint64")
for val in [5, -5]:
for shift in [nbits, nbits + 4]:
val_scl = np.array(val).astype(dt)[()]

View File

@ -18,7 +18,7 @@ from torch.testing._internal.common_utils import (
dtype_names = [
"bool_",
*[f"int{w}" for w in [8, 16, 32, 64]],
"uint8",
*[f"uint{w}" for w in [8, 16, 32, 64]],
*[f"float{w}" for w in [16, 32, 64]],
*[f"complex{w}" for w in [64, 128]],
]

View File

@ -1189,6 +1189,11 @@ class NumpyTypeInfoVariable(ConstantLikeVariable):
class NumpyDTypeVariable(ConstantLikeVariable):
_error_prefix = "np.dtype[...]"
def __init__(self, value, **kwargs):
if isinstance(value, tnp.DType):
value = ConstantLikeVariable.np_dtype(value.name)
super().__init__(value, **kwargs)
def as_proxy(self):
"""Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:

View File

@ -1089,6 +1089,7 @@ class NumpyNdarrayVariable(TensorVariable):
from ..utils import numpy_attr_wrapper
from .builder import wrap_fx_proxy
from .misc import NumpyDTypeVariable
result = None
@ -1135,6 +1136,8 @@ class NumpyNdarrayVariable(TensorVariable):
if not has_free_symbols(r := example_ndarray.size):
return ConstantVariable.create(int(r))
return insert_into_graph()
if name == "dtype":
return NumpyDTypeVariable(example_ndarray.dtype)
elif name in ["base", "flags", "dtype"]:
unimplemented(f"TODO: add support for ndarray.{name}")
elif name in ["__version__"]:

View File

@ -3,7 +3,7 @@
import torch
# These two dicts are autogenerated with autogen/gen_dtypes.py,
# using numpy version 1.23.5.
# using numpy version 1.24.3.
_can_cast_dict = {
"no": {
@ -14,6 +14,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -27,6 +30,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -40,6 +46,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -53,6 +62,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -66,6 +78,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -79,6 +94,57 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: True,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: False,
torch.bool: False,
},
torch.uint16: {
torch.float16: False,
torch.float32: False,
torch.float64: False,
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: True,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: False,
torch.bool: False,
},
torch.uint32: {
torch.float16: False,
torch.float32: False,
torch.float64: False,
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: True,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: False,
torch.bool: False,
},
torch.uint64: {
torch.float16: False,
torch.float32: False,
torch.float64: False,
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: True,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -92,6 +158,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: True,
torch.int16: False,
torch.int32: False,
@ -105,6 +174,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: True,
torch.int32: False,
@ -118,6 +190,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: True,
@ -131,6 +206,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -144,6 +222,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -159,6 +240,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -172,6 +256,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -185,6 +272,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -198,6 +288,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -211,6 +304,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -224,6 +320,57 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: True,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: False,
torch.bool: False,
},
torch.uint16: {
torch.float16: False,
torch.float32: False,
torch.float64: False,
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: True,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: False,
torch.bool: False,
},
torch.uint32: {
torch.float16: False,
torch.float32: False,
torch.float64: False,
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: True,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: False,
torch.bool: False,
},
torch.uint64: {
torch.float16: False,
torch.float32: False,
torch.float64: False,
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: True,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -237,6 +384,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: True,
torch.int16: False,
torch.int32: False,
@ -250,6 +400,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: True,
torch.int32: False,
@ -263,6 +416,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: True,
@ -276,6 +432,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -289,6 +448,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -304,6 +466,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -317,6 +482,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -330,6 +498,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -343,6 +514,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -356,6 +530,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -369,12 +546,63 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: False,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: False,
},
torch.uint16: {
torch.float16: False,
torch.float32: True,
torch.float64: True,
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: False,
torch.int16: False,
torch.int32: True,
torch.int64: True,
torch.bool: False,
},
torch.uint32: {
torch.float16: False,
torch.float32: False,
torch.float64: True,
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: True,
torch.uint64: True,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: True,
torch.bool: False,
},
torch.uint64: {
torch.float16: False,
torch.float32: False,
torch.float64: True,
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: True,
torch.int8: False,
torch.int16: False,
torch.int32: False,
torch.int64: False,
torch.bool: False,
},
torch.int8: {
torch.float16: True,
torch.float32: True,
@ -382,6 +610,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -395,6 +626,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: True,
torch.int32: True,
@ -408,6 +642,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: True,
@ -421,6 +658,9 @@ _can_cast_dict = {
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -434,6 +674,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -449,6 +692,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -462,6 +708,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -475,6 +724,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -488,6 +740,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -501,6 +756,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@ -514,6 +772,57 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: False,
},
torch.uint16: {
torch.float16: True,
torch.float32: True,
torch.float64: True,
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: False,
},
torch.uint32: {
torch.float16: True,
torch.float32: True,
torch.float64: True,
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: False,
},
torch.uint64: {
torch.float16: True,
torch.float32: True,
torch.float64: True,
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -527,6 +836,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -540,6 +852,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -553,6 +868,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -566,6 +884,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
torch.uint16: False,
torch.uint32: False,
torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -579,6 +900,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -594,6 +918,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -607,6 +934,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -620,6 +950,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -633,6 +966,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -646,6 +982,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -659,6 +998,57 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: True,
},
torch.uint16: {
torch.float16: True,
torch.float32: True,
torch.float64: True,
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: True,
},
torch.uint32: {
torch.float16: True,
torch.float32: True,
torch.float64: True,
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: True,
},
torch.uint64: {
torch.float16: True,
torch.float32: True,
torch.float64: True,
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -672,6 +1062,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -685,6 +1078,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -698,6 +1094,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -711,6 +1110,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -724,6 +1126,9 @@ _can_cast_dict = {
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
torch.uint16: True,
torch.uint32: True,
torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@ -742,6 +1147,9 @@ _result_type_dict = {
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.float16,
torch.uint16: torch.float32,
torch.uint32: torch.float64,
torch.uint64: torch.float64,
torch.int8: torch.float16,
torch.int16: torch.float32,
torch.int32: torch.float64,
@ -755,6 +1163,9 @@ _result_type_dict = {
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.float32,
torch.uint16: torch.float32,
torch.uint32: torch.float64,
torch.uint64: torch.float64,
torch.int8: torch.float32,
torch.int16: torch.float32,
torch.int32: torch.float64,
@ -768,6 +1179,9 @@ _result_type_dict = {
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.float64,
torch.uint16: torch.float64,
torch.uint32: torch.float64,
torch.uint64: torch.float64,
torch.int8: torch.float64,
torch.int16: torch.float64,
torch.int32: torch.float64,
@ -781,6 +1195,9 @@ _result_type_dict = {
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.complex64,
torch.uint16: torch.complex64,
torch.uint32: torch.complex128,
torch.uint64: torch.complex128,
torch.int8: torch.complex64,
torch.int16: torch.complex64,
torch.int32: torch.complex128,
@ -794,6 +1211,9 @@ _result_type_dict = {
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.complex128,
torch.uint16: torch.complex128,
torch.uint32: torch.complex128,
torch.uint64: torch.complex128,
torch.int8: torch.complex128,
torch.int16: torch.complex128,
torch.int32: torch.complex128,
@ -807,12 +1227,63 @@ _result_type_dict = {
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.uint8,
torch.uint16: torch.uint16,
torch.uint32: torch.uint32,
torch.uint64: torch.uint64,
torch.int8: torch.int16,
torch.int16: torch.int16,
torch.int32: torch.int32,
torch.int64: torch.int64,
torch.bool: torch.uint8,
},
torch.uint16: {
torch.float16: torch.float32,
torch.float32: torch.float32,
torch.float64: torch.float64,
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.uint16,
torch.uint16: torch.uint16,
torch.uint32: torch.uint32,
torch.uint64: torch.uint64,
torch.int8: torch.int32,
torch.int16: torch.int32,
torch.int32: torch.int32,
torch.int64: torch.int64,
torch.bool: torch.uint16,
},
torch.uint32: {
torch.float16: torch.float64,
torch.float32: torch.float64,
torch.float64: torch.float64,
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.uint32,
torch.uint16: torch.uint32,
torch.uint32: torch.uint32,
torch.uint64: torch.uint64,
torch.int8: torch.int64,
torch.int16: torch.int64,
torch.int32: torch.int64,
torch.int64: torch.int64,
torch.bool: torch.uint32,
},
torch.uint64: {
torch.float16: torch.float64,
torch.float32: torch.float64,
torch.float64: torch.float64,
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.uint64,
torch.uint16: torch.uint64,
torch.uint32: torch.uint64,
torch.uint64: torch.uint64,
torch.int8: torch.float64,
torch.int16: torch.float64,
torch.int32: torch.float64,
torch.int64: torch.float64,
torch.bool: torch.uint64,
},
torch.int8: {
torch.float16: torch.float16,
torch.float32: torch.float32,
@ -820,6 +1291,9 @@ _result_type_dict = {
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.int16,
torch.uint16: torch.int32,
torch.uint32: torch.int64,
torch.uint64: torch.float64,
torch.int8: torch.int8,
torch.int16: torch.int16,
torch.int32: torch.int32,
@ -833,6 +1307,9 @@ _result_type_dict = {
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.int16,
torch.uint16: torch.int32,
torch.uint32: torch.int64,
torch.uint64: torch.float64,
torch.int8: torch.int16,
torch.int16: torch.int16,
torch.int32: torch.int32,
@ -846,6 +1323,9 @@ _result_type_dict = {
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.int32,
torch.uint16: torch.int32,
torch.uint32: torch.int64,
torch.uint64: torch.float64,
torch.int8: torch.int32,
torch.int16: torch.int32,
torch.int32: torch.int32,
@ -859,6 +1339,9 @@ _result_type_dict = {
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.int64,
torch.uint16: torch.int64,
torch.uint32: torch.int64,
torch.uint64: torch.float64,
torch.int8: torch.int64,
torch.int16: torch.int64,
torch.int32: torch.int64,
@ -872,6 +1355,9 @@ _result_type_dict = {
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.uint8,
torch.uint16: torch.uint16,
torch.uint32: torch.uint32,
torch.uint64: torch.uint64,
torch.int8: torch.int8,
torch.int16: torch.int16,
torch.int32: torch.int32,

View File

@ -113,6 +113,24 @@ class uint8(unsignedinteger):
torch_dtype = torch.uint8
class uint16(unsignedinteger):
name = "uint16"
typecode = "H"
torch_dtype = torch.uint16
class uint32(signedinteger):
name = "uint32"
typecode = "I"
torch_dtype = torch.uint32
class uint64(signedinteger):
name = "uint64"
typecode = "L"
torch_dtype = torch.uint64
# floating point
@ -160,6 +178,7 @@ _name_aliases = {
"byte": int8,
"short": int16,
"longlong": int64, # XXX: is this correct?
"ulonglong": uint64,
"ubyte": uint8,
"half": float16,
"single": float32,
@ -180,7 +199,7 @@ for name, obj in _name_aliases.items():
# cf tests/core/test_scalar_methods.py
sctypes = {
"int": [int8, int16, int32, int64],
"uint": [uint8],
"uint": [uint8, uint16, uint32, uint64],
"float": [float16, float32, float64],
"complex": [complex64, complex128],
"others": [bool_],

View File

@ -1500,31 +1500,31 @@ TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_
# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
numpy_to_torch_dtype_dict = {
np.bool_ : torch.bool,
np.uint8 : torch.uint8,
np.uint16 : torch.uint16,
np.uint32 : torch.uint32,
np.uint64 : torch.uint64,
np.int8 : torch.int8,
np.int16 : torch.int16,
np.int32 : torch.int32,
np.int64 : torch.int64,
np.float16 : torch.float16,
np.float32 : torch.float32,
np.float64 : torch.float64,
np.complex64 : torch.complex64,
np.complex128 : torch.complex128
np.dtype(np.bool_) : torch.bool,
np.dtype(np.uint8) : torch.uint8,
np.dtype(np.uint16) : torch.uint16,
np.dtype(np.uint32) : torch.uint32,
np.dtype(np.uint64) : torch.uint64,
np.dtype(np.int8) : torch.int8,
np.dtype(np.int16) : torch.int16,
np.dtype(np.int32) : torch.int32,
np.dtype(np.int64) : torch.int64,
np.dtype(np.float16) : torch.float16,
np.dtype(np.float32) : torch.float32,
np.dtype(np.float64) : torch.float64,
np.dtype(np.complex64) : torch.complex64,
np.dtype(np.complex128): torch.complex128
}
# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
# numpy dtypes like np.float64 are not instances, but rather classes. This leads
# to rather absurd cases like np.float64 != np.dtype("float64") but
# np.dtype(np.float64) == np.dtype("float64") and
# np.dtype(np.dtype("float64")) == np.dtype("float64"). Especially when
# checking against a reference we can't be sure which variant we get, so we
# simply apply the conversion.
def numpy_to_torch_dtype(np_dtype):
try:
return numpy_to_torch_dtype_dict[np_dtype]
except KeyError:
return numpy_to_torch_dtype_dict[np_dtype.type]
return numpy_to_torch_dtype_dict[np.dtype(np_dtype)]
def has_corresponding_torch_dtype(np_dtype):