Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[dynamo] Support ndarray.dtype attribute access (#124490)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124490
Approved by: https://github.com/lezcano
ghstack dependencies: #125717

Committed by: PyTorch MergeBot
Commit: 4adee71155
Parent: a9cc147fa1
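In practice, this change lets torch.compile trace reads of ndarray.dtype instead of treating them as unimplemented. A minimal sketch of the user-visible effect, modeled on the test added below (the comments describe the tracing behavior, not code from the patch):

    import numpy as np
    import torch

    def fn(t):
        x = t.numpy()          # traced as a NumpyNdarrayVariable under dynamo
        dt = x.dtype           # previously unsupported; now a NumpyDTypeVariable
        return np.full_like(x, 2.4, dtype=dt)

    compiled = torch.compile(fn)
    out = compiled(torch.ones(3))   # float32 array filled with 2.4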
@@ -1619,6 +1619,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
         dt = np.dtype("float")
         return np.full_like(x, 2.4, dtype=dt)
 
+    @make_test
+    def test_numpy_dtype_attr(x):
+        return np.ones_like(x).dtype == x.dtype
+
     @make_test
     def test_numpy_linalg(x):
         return np.linalg.norm(x.numpy(), axis=0)
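For reference, the invariant the new test_numpy_dtype_attr exercises already holds in plain NumPy, since ones_like preserves its input's dtype; a quick standalone check (not part of the patch):

    import numpy as np

    x = np.arange(4, dtype=np.float32)
    assert np.ones_like(x).dtype == x.dtype   # ones_like preserves the dtype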
@@ -121,9 +121,7 @@ class TestBinaryUfuncs(TestCase):
     def _helper_reference_numerics(
         expected, actual, msg, exact_dtype, equal_nan=True
     ):
-        if not torch.can_cast(
-            numpy_to_torch_dtype_dict[expected.dtype.type], dtype
-        ):
+        if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype], dtype):
             exact_dtype = False
 
         if dtype is torch.bfloat16 and expected.dtype == np.float32:
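The one-line form works because numpy_to_torch_dtype_dict is now keyed by np.dtype instances (see the common_utils hunk at the end of this commit), so expected.dtype is a valid key without the .type indirection. An illustrative standalone check, assuming the updated table:

    import numpy as np
    import torch
    from torch.testing._internal.common_utils import numpy_to_torch_dtype_dict

    expected = np.ones(3, dtype=np.float64)
    src = numpy_to_torch_dtype_dict[expected.dtype]   # direct np.dtype key lookup
    assert src is torch.float64
    assert torch.can_cast(src, torch.float16)         # float-to-float casts are allowed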
@@ -184,7 +184,7 @@ class TestUnaryUfuncs(TestCase):
         expected, actual, msg, exact_dtype, equal_nan=True
     ):
         if not torch.can_cast(
-            numpy_to_torch_dtype_dict[expected.dtype.type], dtype
+            numpy_to_torch_dtype_dict[expected.dtype], dtype
         ):
             exact_dtype = False
 
@@ -1833,7 +1833,7 @@ class TestMethods(TestCase):
         a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.unicode_)
         assert_equal(a.argsort(kind="m"), r)
 
-    @xpassIfTorchDynamo  # (reason="TODO: searchsorted with nans differs in pytorch")
+    @xfail  # (reason="TODO: searchsorted with nans differs in pytorch")
     @parametrize(
         "a",
         [
@@ -1905,7 +1905,7 @@ class TestMethods(TestCase):
         b = a.searchsorted([0, 1, 2], "right")
         assert_equal(b, [0, 2, 2])
 
-    @xpassIfTorchDynamo  # (
+    @xfail  # (
     #     reason="RuntimeError: self.storage_offset() must be divisible by 8"
     # )
     def test_searchsorted_unaligned_array(self):
@@ -1984,7 +1984,7 @@ class TestMethods(TestCase):
         # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1, 0, 1, 2, 3])
         # assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4, 0, -1, 2, 3])
 
-    @xpassIfTorchDynamo  # (reason="self.storage_offset() must be divisible by 8")
+    @xfail  # (reason="self.storage_offset() must be divisible by 8")
     def test_searchsorted_with_sorter(self):
         a = np.random.rand(300)
         s = a.argsort()
@@ -3713,7 +3713,14 @@ class TestTake(TestCase):
         y = np.take(x, [1, 2, 3], out=x[2:5], mode="wrap")
         assert_equal(y, np.array([1, 2, 3]))
 
-    @parametrize("shape", [(1, 2), (1,), ()])
+    @parametrize(
+        "shape",
+        [
+            subtest((1, 2)),
+            subtest((1,)),
+            subtest((), decorators=[skip("Sensitive to np version")]),
+        ],
+    )
     def test_ret_is_out(self, shape):
         # 0d arrays should not be an exception to this rule
         x = np.arange(5)
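subtest wraps a single parametrized case so it can carry its own name or decorators, which is what lets the 0-d case alone be skipped here. A hedged usage sketch of torch's parametrize/subtest helpers (the test class is illustrative, and unittest.skip stands in for whatever skip helper the test file imports):

    import unittest
    from torch.testing._internal.common_utils import (
        TestCase,
        instantiate_parametrized_tests,
        parametrize,
        subtest,
    )

    class ShapeTests(TestCase):
        @parametrize(
            "shape",
            [
                subtest((1, 2)),
                subtest((1,)),
                # Per-case decorators let one variant be skipped or xfailed:
                subtest((), decorators=[unittest.skip("Sensitive to np version")]),
            ],
        )
        def test_shapes(self, shape):
            self.assertIsInstance(shape, tuple)

    instantiate_parametrized_tests(ShapeTests)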
@@ -1189,6 +1189,11 @@ class NumpyTypeInfoVariable(ConstantLikeVariable):
 class NumpyDTypeVariable(ConstantLikeVariable):
     _error_prefix = "np.dtype[...]"
 
+    def __init__(self, value, **kwargs):
+        if isinstance(value, tnp.DType):
+            value = ConstantLikeVariable.np_dtype(value.name)
+        super().__init__(value, **kwargs)
+
     def as_proxy(self):
         """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:
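The new __init__ canonicalizes torch._numpy's own DType objects into real numpy dtype descriptors by round-tripping through the dtype name. At the NumPy level, the round-trip it relies on looks like this (standalone illustration, not code from the patch):

    import numpy as np

    name = np.dtype(np.float32).name                # 'float32'
    assert np.dtype(name) == np.dtype(np.float32)   # the name round-trips to an equal dtype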
@@ -1089,6 +1089,7 @@ class NumpyNdarrayVariable(TensorVariable):
 
         from ..utils import numpy_attr_wrapper
         from .builder import wrap_fx_proxy
+        from .misc import NumpyDTypeVariable
 
         result = None
 
@@ -1135,6 +1136,8 @@ class NumpyNdarrayVariable(TensorVariable):
             if not has_free_symbols(r := example_ndarray.size):
                 return ConstantVariable.create(int(r))
             return insert_into_graph()
+        if name == "dtype":
+            return NumpyDTypeVariable(example_ndarray.dtype)
         elif name in ["base", "flags", "dtype"]:
             unimplemented(f"TODO: add support for ndarray.{name}")
         elif name in ["__version__"]:
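Since the dtype is now represented as a constant-like value at trace time rather than a graph node, dtype-dependent Python control flow can be resolved during compilation instead of hitting unimplemented(). A hedged sketch of the intended effect:

    import numpy as np
    import torch

    def fn(t):
        x = t.numpy()
        # x.dtype is a trace-time constant, so this branch is resolved
        # while compiling rather than causing a graph break:
        if x.dtype == np.float32:
            return x + 1.0
        return x - 1.0

    compiled = torch.compile(fn)
    compiled(torch.ones(3))   # traces the float32 branch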
@@ -1500,31 +1500,31 @@ TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_
 
 # Dict of NumPy dtype -> torch dtype (when the correspondence exists)
 numpy_to_torch_dtype_dict = {
-    np.bool_      : torch.bool,
-    np.uint8      : torch.uint8,
-    np.uint16     : torch.uint16,
-    np.uint32     : torch.uint32,
-    np.uint64     : torch.uint64,
-    np.int8       : torch.int8,
-    np.int16      : torch.int16,
-    np.int32      : torch.int32,
-    np.int64      : torch.int64,
-    np.float16    : torch.float16,
-    np.float32    : torch.float32,
-    np.float64    : torch.float64,
-    np.complex64  : torch.complex64,
-    np.complex128 : torch.complex128
+    np.dtype(np.bool_)      : torch.bool,
+    np.dtype(np.uint8)      : torch.uint8,
+    np.dtype(np.uint16)     : torch.uint16,
+    np.dtype(np.uint32)     : torch.uint32,
+    np.dtype(np.uint64)     : torch.uint64,
+    np.dtype(np.int8)       : torch.int8,
+    np.dtype(np.int16)      : torch.int16,
+    np.dtype(np.int32)      : torch.int32,
+    np.dtype(np.int64)      : torch.int64,
+    np.dtype(np.float16)    : torch.float16,
+    np.dtype(np.float32)    : torch.float32,
+    np.dtype(np.float64)    : torch.float64,
+    np.dtype(np.complex64)  : torch.complex64,
+    np.dtype(np.complex128) : torch.complex128
 }
 
 
-# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
-# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
-# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
+# numpy dtypes like np.float64 are not instances, but rather classes. This leads
+# to rather absurd cases like np.float64 != np.dtype("float64") but
+# np.dtype(np.float64) == np.dtype("float64") and
+# np.dtype(np.dtype("float64")) == np.dtype("float64"). Especially when
+# checking against a reference we can't be sure which variant we get, so we
+# simply apply the conversion.
 def numpy_to_torch_dtype(np_dtype):
-    try:
-        return numpy_to_torch_dtype_dict[np_dtype]
-    except KeyError:
-        return numpy_to_torch_dtype_dict[np_dtype.type]
+    return numpy_to_torch_dtype_dict[np.dtype(np_dtype)]
 
 
 def has_corresponding_torch_dtype(np_dtype):
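Because np.dtype() is idempotent and accepts scalar types, strings, and dtype instances alike, the rewritten helper handles every spelling through a single lookup; a standalone check of that property:

    import numpy as np
    import torch
    from torch.testing._internal.common_utils import numpy_to_torch_dtype

    # All spellings normalize to the same np.dtype key:
    assert numpy_to_torch_dtype(np.float64) is torch.float64
    assert numpy_to_torch_dtype("float64") is torch.float64
    assert numpy_to_torch_dtype(np.dtype("float64")) is torch.float64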