mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Compare commits
47 Commits
v2.6.0-rc1
...
gh/amjames
| Author | SHA1 | Date | |
|---|---|---|---|
| e8a718a7e4 | |||
| 591bc43d0d | |||
| 3fc6dab449 | |||
| 6bd96079f3 | |||
| 20a9697b09 | |||
| 452d20c56a | |||
| 8147c5ae73 | |||
| f92d37e6de | |||
| 42198e7aaa | |||
| e6b868fbf9 | |||
| 38b220eb18 | |||
| 5d586eebd3 | |||
| 6d347bbfe9 | |||
| ca64f38a59 | |||
| 0edec418e0 | |||
| bd69e1cd4a | |||
| 90ae1c0491 | |||
| 133c608201 | |||
| 68eef4cbb6 | |||
| 0f14d35ddc | |||
| f7d6eb01ce | |||
| f109219fd7 | |||
| 0dd27104b8 | |||
| dc9580b0c9 | |||
| 09dbeb7ecd | |||
| f7eb123b20 | |||
| 5844251957 | |||
| a25f42aeb2 | |||
| 6475c24efa | |||
| 48f51d3afa | |||
| ec773a663c | |||
| 0b0a88e4ea | |||
| 1d47fb54cc | |||
| cda63f9980 | |||
| af30f8e4a3 | |||
| b516b013d9 | |||
| 55f2145f03 | |||
| 58ac759f97 | |||
| 7a9b644baa | |||
| 7a18fd2249 | |||
| e5619f7012 | |||
| 0d04859ed7 | |||
| e7ca7712ed | |||
| 9f60bd5753 | |||
| 55644249a8 | |||
| f7ede74d24 | |||
| cfeeb15a75 |
@ -1619,6 +1619,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
dt = np.dtype("float")
|
||||
return np.full_like(x, 2.4, dtype=dt)
|
||||
|
||||
@make_test
|
||||
def test_numpy_dtype_attr(x):
|
||||
return np.ones_like(x).dtype == x.dtype
|
||||
|
||||
@make_test
|
||||
def test_numpy_linalg(x):
|
||||
return np.linalg.norm(x.numpy(), axis=0)
|
||||
|
||||
@ -121,9 +121,7 @@ class TestBinaryUfuncs(TestCase):
|
||||
def _helper_reference_numerics(
|
||||
expected, actual, msg, exact_dtype, equal_nan=True
|
||||
):
|
||||
if not torch.can_cast(
|
||||
numpy_to_torch_dtype_dict[expected.dtype.type], dtype
|
||||
):
|
||||
if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype], dtype):
|
||||
exact_dtype = False
|
||||
|
||||
if dtype is torch.bfloat16 and expected.dtype == np.float32:
|
||||
|
||||
@ -476,13 +476,18 @@ class TestNumPyInterop(TestCase):
|
||||
self.assertTrue(r2.requires_grad)
|
||||
|
||||
@onlyCPU
|
||||
def test_parse_numpy_int(self, device):
|
||||
@skipIfTorchDynamo()
|
||||
def test_parse_numpy_int_overflow(self, device):
|
||||
# assertRaises uses a try-except which dynamo has issues with
|
||||
# Only concrete class can be given where "Type[number[_64Bit]]" is expected
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"(Overflow|an integer is required)",
|
||||
lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)),
|
||||
) # type: ignore[call-overload]
|
||||
|
||||
@onlyCPU
|
||||
def test_parse_numpy_int(self, device):
|
||||
# https://github.com/pytorch/pytorch/issues/29252
|
||||
for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
|
||||
scalar = 3
|
||||
|
||||
@ -184,7 +184,7 @@ class TestUnaryUfuncs(TestCase):
|
||||
expected, actual, msg, exact_dtype, equal_nan=True
|
||||
):
|
||||
if not torch.can_cast(
|
||||
numpy_to_torch_dtype_dict[expected.dtype.type], dtype
|
||||
numpy_to_torch_dtype_dict[expected.dtype], dtype
|
||||
):
|
||||
exact_dtype = False
|
||||
|
||||
|
||||
@ -8,14 +8,17 @@ import warnings
|
||||
|
||||
# from numpy.core.getlimits import _discovered_machar, _float_ma
|
||||
|
||||
from unittest import skipIf
|
||||
from unittest import expectedFailure as xfail, skipIf
|
||||
|
||||
import numpy
|
||||
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
subtest,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
TestCase,
|
||||
xpassIfTorchDynamo,
|
||||
@ -109,6 +112,7 @@ class TestFinfo(TestCase):
|
||||
getattr(finfo(dt), attr)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestIinfo(TestCase):
|
||||
def test_basic(self):
|
||||
dts = list(
|
||||
@ -129,11 +133,19 @@ class TestIinfo(TestCase):
|
||||
with assert_raises((TypeError, ValueError)):
|
||||
iinfo("f4")
|
||||
|
||||
def test_unsigned_max(self):
|
||||
types = np.sctypes["uint"]
|
||||
for T in types:
|
||||
max_calculated = T(0) - T(1)
|
||||
assert_equal(iinfo(T).max, max_calculated)
|
||||
@parametrize(
|
||||
"T",
|
||||
[
|
||||
np.uint8,
|
||||
# xfail: unsupported add (uint[16,32,64])
|
||||
subtest(np.uint16, decorators=[xfail]),
|
||||
subtest(np.uint32, decorators=[xfail]),
|
||||
subtest(np.uint64, decorators=[xfail]),
|
||||
],
|
||||
)
|
||||
def test_unsigned_max(self, T):
|
||||
max_calculated = T(0) - T(1)
|
||||
assert_equal(iinfo(T).max, max_calculated)
|
||||
|
||||
|
||||
class TestRepr(TestCase):
|
||||
|
||||
@ -1833,7 +1833,7 @@ class TestMethods(TestCase):
|
||||
a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.unicode_)
|
||||
assert_equal(a.argsort(kind="m"), r)
|
||||
|
||||
@xpassIfTorchDynamo # (reason="TODO: searchsorted with nans differs in pytorch")
|
||||
@xfail # (reason="TODO: searchsorted with nans differs in pytorch")
|
||||
@parametrize(
|
||||
"a",
|
||||
[
|
||||
@ -1905,7 +1905,7 @@ class TestMethods(TestCase):
|
||||
b = a.searchsorted([0, 1, 2], "right")
|
||||
assert_equal(b, [0, 2, 2])
|
||||
|
||||
@xpassIfTorchDynamo # (
|
||||
@xfail # (
|
||||
# reason="RuntimeError: self.storage_offset() must be divisible by 8"
|
||||
# )
|
||||
def test_searchsorted_unaligned_array(self):
|
||||
@ -1984,7 +1984,7 @@ class TestMethods(TestCase):
|
||||
# assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1, 0, 1, 2, 3])
|
||||
# assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4, 0, -1, 2, 3])
|
||||
|
||||
@xpassIfTorchDynamo # (reason="self.storage_offset() must be divisible by 8")
|
||||
@xfail # (reason="self.storage_offset() must be divisible by 8")
|
||||
def test_searchsorted_with_sorter(self):
|
||||
a = np.random.rand(300)
|
||||
s = a.argsort()
|
||||
@ -3713,7 +3713,14 @@ class TestTake(TestCase):
|
||||
y = np.take(x, [1, 2, 3], out=x[2:5], mode="wrap")
|
||||
assert_equal(y, np.array([1, 2, 3]))
|
||||
|
||||
@parametrize("shape", [(1, 2), (1,), ()])
|
||||
@parametrize(
|
||||
"shape",
|
||||
[
|
||||
subtest((1, 2)),
|
||||
subtest((1,)),
|
||||
subtest((), decorators=[skip("Sensitive to np version")]),
|
||||
],
|
||||
)
|
||||
def test_ret_is_out(self, shape):
|
||||
# 0d arrays should not be an exception to this rule
|
||||
x = np.arange(5)
|
||||
|
||||
@ -732,13 +732,16 @@ class TestAbs(TestCase):
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestBitShifts(TestCase):
|
||||
@parametrize("type_code", np.typecodes["Integer"] + "B")
|
||||
@parametrize("type_code", np.typecodes["AllInteger"])
|
||||
@parametrize("op", [operator.rshift, operator.lshift])
|
||||
def test_shift_all_bits(self, type_code, op):
|
||||
"""Shifts where the shift amount is the width of the type or wider"""
|
||||
# gh-2449
|
||||
dt = np.dtype(type_code)
|
||||
nbits = dt.itemsize * 8
|
||||
if dt in (np.dtype(np.uint64), np.dtype(np.uint32), np.dtype(np.uint16)):
|
||||
raise SkipTest("NYI: bitshift uint64")
|
||||
|
||||
for val in [5, -5]:
|
||||
for shift in [nbits, nbits + 4]:
|
||||
val_scl = np.array(val).astype(dt)[()]
|
||||
|
||||
@ -18,7 +18,7 @@ from torch.testing._internal.common_utils import (
|
||||
dtype_names = [
|
||||
"bool_",
|
||||
*[f"int{w}" for w in [8, 16, 32, 64]],
|
||||
"uint8",
|
||||
*[f"uint{w}" for w in [8, 16, 32, 64]],
|
||||
*[f"float{w}" for w in [16, 32, 64]],
|
||||
*[f"complex{w}" for w in [64, 128]],
|
||||
]
|
||||
|
||||
@ -1189,6 +1189,11 @@ class NumpyTypeInfoVariable(ConstantLikeVariable):
|
||||
class NumpyDTypeVariable(ConstantLikeVariable):
|
||||
_error_prefix = "np.dtype[...]"
|
||||
|
||||
def __init__(self, value, **kwargs):
|
||||
if isinstance(value, tnp.DType):
|
||||
value = ConstantLikeVariable.np_dtype(value.name)
|
||||
super().__init__(value, **kwargs)
|
||||
|
||||
def as_proxy(self):
|
||||
"""Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:
|
||||
|
||||
|
||||
@ -1089,6 +1089,7 @@ class NumpyNdarrayVariable(TensorVariable):
|
||||
|
||||
from ..utils import numpy_attr_wrapper
|
||||
from .builder import wrap_fx_proxy
|
||||
from .misc import NumpyDTypeVariable
|
||||
|
||||
result = None
|
||||
|
||||
@ -1135,6 +1136,8 @@ class NumpyNdarrayVariable(TensorVariable):
|
||||
if not has_free_symbols(r := example_ndarray.size):
|
||||
return ConstantVariable.create(int(r))
|
||||
return insert_into_graph()
|
||||
if name == "dtype":
|
||||
return NumpyDTypeVariable(example_ndarray.dtype)
|
||||
elif name in ["base", "flags", "dtype"]:
|
||||
unimplemented(f"TODO: add support for ndarray.{name}")
|
||||
elif name in ["__version__"]:
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
|
||||
# These two dicts are autogenerated with autogen/gen_dtypes.py,
|
||||
# using numpy version 1.23.5.
|
||||
# using numpy version 1.24.3.
|
||||
|
||||
_can_cast_dict = {
|
||||
"no": {
|
||||
@ -14,6 +14,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -27,6 +30,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -40,6 +46,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -53,6 +62,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -66,6 +78,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -79,6 +94,57 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: True,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: True,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: True,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -92,6 +158,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -105,6 +174,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: False,
|
||||
@ -118,6 +190,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
@ -131,6 +206,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -144,6 +222,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -159,6 +240,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -172,6 +256,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -185,6 +272,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -198,6 +288,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -211,6 +304,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -224,6 +320,57 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: True,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: True,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: True,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: False,
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -237,6 +384,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -250,6 +400,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: False,
|
||||
@ -263,6 +416,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
@ -276,6 +432,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -289,6 +448,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: False,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -304,6 +466,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -317,6 +482,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -330,6 +498,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -343,6 +514,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -356,6 +530,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -369,12 +546,63 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: False,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: True,
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: False,
|
||||
torch.float32: False,
|
||||
torch.float64: True,
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: True,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
torch.int64: False,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.int8: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
@ -382,6 +610,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -395,6 +626,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -408,6 +642,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: True,
|
||||
@ -421,6 +658,9 @@ _can_cast_dict = {
|
||||
torch.complex64: False,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -434,6 +674,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -449,6 +692,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -462,6 +708,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -475,6 +724,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -488,6 +740,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -501,6 +756,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: False,
|
||||
torch.int16: False,
|
||||
torch.int32: False,
|
||||
@ -514,6 +772,57 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: False,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -527,6 +836,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -540,6 +852,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -553,6 +868,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -566,6 +884,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: False,
|
||||
torch.uint16: False,
|
||||
torch.uint32: False,
|
||||
torch.uint64: False,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -579,6 +900,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -594,6 +918,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -607,6 +934,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -620,6 +950,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -633,6 +966,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -646,6 +982,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -659,6 +998,57 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: True,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: True,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
torch.int64: True,
|
||||
torch.bool: True,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: True,
|
||||
torch.float32: True,
|
||||
torch.float64: True,
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -672,6 +1062,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -685,6 +1078,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -698,6 +1094,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -711,6 +1110,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -724,6 +1126,9 @@ _can_cast_dict = {
|
||||
torch.complex64: True,
|
||||
torch.complex128: True,
|
||||
torch.uint8: True,
|
||||
torch.uint16: True,
|
||||
torch.uint32: True,
|
||||
torch.uint64: True,
|
||||
torch.int8: True,
|
||||
torch.int16: True,
|
||||
torch.int32: True,
|
||||
@ -742,6 +1147,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.float16,
|
||||
torch.uint16: torch.float32,
|
||||
torch.uint32: torch.float64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.float16,
|
||||
torch.int16: torch.float32,
|
||||
torch.int32: torch.float64,
|
||||
@ -755,6 +1163,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.float32,
|
||||
torch.uint16: torch.float32,
|
||||
torch.uint32: torch.float64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.float32,
|
||||
torch.int16: torch.float32,
|
||||
torch.int32: torch.float64,
|
||||
@ -768,6 +1179,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.float64,
|
||||
torch.uint16: torch.float64,
|
||||
torch.uint32: torch.float64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.float64,
|
||||
torch.int16: torch.float64,
|
||||
torch.int32: torch.float64,
|
||||
@ -781,6 +1195,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.complex64,
|
||||
torch.uint16: torch.complex64,
|
||||
torch.uint32: torch.complex128,
|
||||
torch.uint64: torch.complex128,
|
||||
torch.int8: torch.complex64,
|
||||
torch.int16: torch.complex64,
|
||||
torch.int32: torch.complex128,
|
||||
@ -794,6 +1211,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.complex128,
|
||||
torch.uint16: torch.complex128,
|
||||
torch.uint32: torch.complex128,
|
||||
torch.uint64: torch.complex128,
|
||||
torch.int8: torch.complex128,
|
||||
torch.int16: torch.complex128,
|
||||
torch.int32: torch.complex128,
|
||||
@ -807,12 +1227,63 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint8,
|
||||
torch.uint16: torch.uint16,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int16,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
torch.int64: torch.int64,
|
||||
torch.bool: torch.uint8,
|
||||
},
|
||||
torch.uint16: {
|
||||
torch.float16: torch.float32,
|
||||
torch.float32: torch.float32,
|
||||
torch.float64: torch.float64,
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint16,
|
||||
torch.uint16: torch.uint16,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int32,
|
||||
torch.int16: torch.int32,
|
||||
torch.int32: torch.int32,
|
||||
torch.int64: torch.int64,
|
||||
torch.bool: torch.uint16,
|
||||
},
|
||||
torch.uint32: {
|
||||
torch.float16: torch.float64,
|
||||
torch.float32: torch.float64,
|
||||
torch.float64: torch.float64,
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint32,
|
||||
torch.uint16: torch.uint32,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int64,
|
||||
torch.int16: torch.int64,
|
||||
torch.int32: torch.int64,
|
||||
torch.int64: torch.int64,
|
||||
torch.bool: torch.uint32,
|
||||
},
|
||||
torch.uint64: {
|
||||
torch.float16: torch.float64,
|
||||
torch.float32: torch.float64,
|
||||
torch.float64: torch.float64,
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint64,
|
||||
torch.uint16: torch.uint64,
|
||||
torch.uint32: torch.uint64,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.float64,
|
||||
torch.int16: torch.float64,
|
||||
torch.int32: torch.float64,
|
||||
torch.int64: torch.float64,
|
||||
torch.bool: torch.uint64,
|
||||
},
|
||||
torch.int8: {
|
||||
torch.float16: torch.float16,
|
||||
torch.float32: torch.float32,
|
||||
@ -820,6 +1291,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int16,
|
||||
torch.uint16: torch.int32,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int8,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
@ -833,6 +1307,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int16,
|
||||
torch.uint16: torch.int32,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int16,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
@ -846,6 +1323,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int32,
|
||||
torch.uint16: torch.int32,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int32,
|
||||
torch.int16: torch.int32,
|
||||
torch.int32: torch.int32,
|
||||
@ -859,6 +1339,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex128,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.int64,
|
||||
torch.uint16: torch.int64,
|
||||
torch.uint32: torch.int64,
|
||||
torch.uint64: torch.float64,
|
||||
torch.int8: torch.int64,
|
||||
torch.int16: torch.int64,
|
||||
torch.int32: torch.int64,
|
||||
@ -872,6 +1355,9 @@ _result_type_dict = {
|
||||
torch.complex64: torch.complex64,
|
||||
torch.complex128: torch.complex128,
|
||||
torch.uint8: torch.uint8,
|
||||
torch.uint16: torch.uint16,
|
||||
torch.uint32: torch.uint32,
|
||||
torch.uint64: torch.uint64,
|
||||
torch.int8: torch.int8,
|
||||
torch.int16: torch.int16,
|
||||
torch.int32: torch.int32,
|
||||
|
||||
@ -113,6 +113,24 @@ class uint8(unsignedinteger):
|
||||
torch_dtype = torch.uint8
|
||||
|
||||
|
||||
class uint16(unsignedinteger):
|
||||
name = "uint16"
|
||||
typecode = "H"
|
||||
torch_dtype = torch.uint16
|
||||
|
||||
|
||||
class uint32(signedinteger):
|
||||
name = "uint32"
|
||||
typecode = "I"
|
||||
torch_dtype = torch.uint32
|
||||
|
||||
|
||||
class uint64(signedinteger):
|
||||
name = "uint64"
|
||||
typecode = "L"
|
||||
torch_dtype = torch.uint64
|
||||
|
||||
|
||||
# floating point
|
||||
|
||||
|
||||
@ -160,6 +178,7 @@ _name_aliases = {
|
||||
"byte": int8,
|
||||
"short": int16,
|
||||
"longlong": int64, # XXX: is this correct?
|
||||
"ulonglong": uint64,
|
||||
"ubyte": uint8,
|
||||
"half": float16,
|
||||
"single": float32,
|
||||
@ -180,7 +199,7 @@ for name, obj in _name_aliases.items():
|
||||
# cf tests/core/test_scalar_methods.py
|
||||
sctypes = {
|
||||
"int": [int8, int16, int32, int64],
|
||||
"uint": [uint8],
|
||||
"uint": [uint8, uint16, uint32, uint64],
|
||||
"float": [float16, float32, float64],
|
||||
"complex": [complex64, complex128],
|
||||
"others": [bool_],
|
||||
|
||||
@ -1500,31 +1500,31 @@ TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_
|
||||
|
||||
# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
|
||||
numpy_to_torch_dtype_dict = {
|
||||
np.bool_ : torch.bool,
|
||||
np.uint8 : torch.uint8,
|
||||
np.uint16 : torch.uint16,
|
||||
np.uint32 : torch.uint32,
|
||||
np.uint64 : torch.uint64,
|
||||
np.int8 : torch.int8,
|
||||
np.int16 : torch.int16,
|
||||
np.int32 : torch.int32,
|
||||
np.int64 : torch.int64,
|
||||
np.float16 : torch.float16,
|
||||
np.float32 : torch.float32,
|
||||
np.float64 : torch.float64,
|
||||
np.complex64 : torch.complex64,
|
||||
np.complex128 : torch.complex128
|
||||
np.dtype(np.bool_) : torch.bool,
|
||||
np.dtype(np.uint8) : torch.uint8,
|
||||
np.dtype(np.uint16) : torch.uint16,
|
||||
np.dtype(np.uint32) : torch.uint32,
|
||||
np.dtype(np.uint64) : torch.uint64,
|
||||
np.dtype(np.int8) : torch.int8,
|
||||
np.dtype(np.int16) : torch.int16,
|
||||
np.dtype(np.int32) : torch.int32,
|
||||
np.dtype(np.int64) : torch.int64,
|
||||
np.dtype(np.float16) : torch.float16,
|
||||
np.dtype(np.float32) : torch.float32,
|
||||
np.dtype(np.float64) : torch.float64,
|
||||
np.dtype(np.complex64) : torch.complex64,
|
||||
np.dtype(np.complex128): torch.complex128
|
||||
}
|
||||
|
||||
|
||||
# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
|
||||
# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
|
||||
# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
|
||||
# numpy dtypes like np.float64 are not instances, but rather classes. This leads
|
||||
# to rather absurd cases like np.float64 != np.dtype("float64") but
|
||||
# np.dtype(np.float64) == np.dtype("float64") and
|
||||
# np.dtype(np.dtype("float64")) == np.dtype("float64"). Especially when
|
||||
# checking against a reference we can't be sure which variant we get, so we
|
||||
# simply apply the conversion.
|
||||
def numpy_to_torch_dtype(np_dtype):
|
||||
try:
|
||||
return numpy_to_torch_dtype_dict[np_dtype]
|
||||
except KeyError:
|
||||
return numpy_to_torch_dtype_dict[np_dtype.type]
|
||||
return numpy_to_torch_dtype_dict[np.dtype(np_dtype)]
|
||||
|
||||
|
||||
def has_corresponding_torch_dtype(np_dtype):
|
||||
|
||||
Reference in New Issue
Block a user