[dynamo] Support ndarray.dtype attribute access (#124490)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124490
Approved by: https://github.com/lezcano
ghstack dependencies: #125717
This commit is contained in:
Andrew M. James
2024-06-04 17:56:39 +00:00
committed by PyTorch MergeBot
parent a9cc147fa1
commit 4adee71155
7 changed files with 46 additions and 29 deletions

View File

@ -1619,6 +1619,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
dt = np.dtype("float")
return np.full_like(x, 2.4, dtype=dt)
@make_test
def test_numpy_dtype_attr(x):
return np.ones_like(x).dtype == x.dtype
@make_test
def test_numpy_linalg(x):
return np.linalg.norm(x.numpy(), axis=0)

View File

@ -121,9 +121,7 @@ class TestBinaryUfuncs(TestCase):
def _helper_reference_numerics(
expected, actual, msg, exact_dtype, equal_nan=True
):
if not torch.can_cast(
numpy_to_torch_dtype_dict[expected.dtype.type], dtype
):
if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype], dtype):
exact_dtype = False
if dtype is torch.bfloat16 and expected.dtype == np.float32:

View File

@ -184,7 +184,7 @@ class TestUnaryUfuncs(TestCase):
expected, actual, msg, exact_dtype, equal_nan=True
):
if not torch.can_cast(
numpy_to_torch_dtype_dict[expected.dtype.type], dtype
numpy_to_torch_dtype_dict[expected.dtype], dtype
):
exact_dtype = False

View File

@ -1833,7 +1833,7 @@ class TestMethods(TestCase):
a = np.array(["aaaaaaaaa" for i in range(100)], dtype=np.unicode_)
assert_equal(a.argsort(kind="m"), r)
@xpassIfTorchDynamo # (reason="TODO: searchsorted with nans differs in pytorch")
@xfail # (reason="TODO: searchsorted with nans differs in pytorch")
@parametrize(
"a",
[
@ -1905,7 +1905,7 @@ class TestMethods(TestCase):
b = a.searchsorted([0, 1, 2], "right")
assert_equal(b, [0, 2, 2])
@xpassIfTorchDynamo # (
@xfail # (
# reason="RuntimeError: self.storage_offset() must be divisible by 8"
# )
def test_searchsorted_unaligned_array(self):
@ -1984,7 +1984,7 @@ class TestMethods(TestCase):
# assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1, 0, 1, 2, 3])
# assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4, 0, -1, 2, 3])
@xpassIfTorchDynamo # (reason="self.storage_offset() must be divisible by 8")
@xfail # (reason="self.storage_offset() must be divisible by 8")
def test_searchsorted_with_sorter(self):
a = np.random.rand(300)
s = a.argsort()
@ -3713,7 +3713,14 @@ class TestTake(TestCase):
y = np.take(x, [1, 2, 3], out=x[2:5], mode="wrap")
assert_equal(y, np.array([1, 2, 3]))
@parametrize("shape", [(1, 2), (1,), ()])
@parametrize(
"shape",
[
subtest((1, 2)),
subtest((1,)),
subtest((), decorators=[skip("Sensitive to np version")]),
],
)
def test_ret_is_out(self, shape):
# 0d arrays should not be an exception to this rule
x = np.arange(5)

View File

@ -1189,6 +1189,11 @@ class NumpyTypeInfoVariable(ConstantLikeVariable):
class NumpyDTypeVariable(ConstantLikeVariable):
_error_prefix = "np.dtype[...]"
def __init__(self, value, **kwargs):
if isinstance(value, tnp.DType):
value = ConstantLikeVariable.np_dtype(value.name)
super().__init__(value, **kwargs)
def as_proxy(self):
"""Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:

View File

@ -1089,6 +1089,7 @@ class NumpyNdarrayVariable(TensorVariable):
from ..utils import numpy_attr_wrapper
from .builder import wrap_fx_proxy
from .misc import NumpyDTypeVariable
result = None
@ -1135,6 +1136,8 @@ class NumpyNdarrayVariable(TensorVariable):
if not has_free_symbols(r := example_ndarray.size):
return ConstantVariable.create(int(r))
return insert_into_graph()
if name == "dtype":
return NumpyDTypeVariable(example_ndarray.dtype)
elif name in ["base", "flags", "dtype"]:
unimplemented(f"TODO: add support for ndarray.{name}")
elif name in ["__version__"]:

View File

@ -1500,31 +1500,31 @@ TestEnvironment.def_flag("TEST_CUDA_MEM_LEAK_CHECK", env_var="PYTORCH_TEST_CUDA_
# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
numpy_to_torch_dtype_dict = {
np.bool_ : torch.bool,
np.uint8 : torch.uint8,
np.uint16 : torch.uint16,
np.uint32 : torch.uint32,
np.uint64 : torch.uint64,
np.int8 : torch.int8,
np.int16 : torch.int16,
np.int32 : torch.int32,
np.int64 : torch.int64,
np.float16 : torch.float16,
np.float32 : torch.float32,
np.float64 : torch.float64,
np.complex64 : torch.complex64,
np.complex128 : torch.complex128
np.dtype(np.bool_) : torch.bool,
np.dtype(np.uint8) : torch.uint8,
np.dtype(np.uint16) : torch.uint16,
np.dtype(np.uint32) : torch.uint32,
np.dtype(np.uint64) : torch.uint64,
np.dtype(np.int8) : torch.int8,
np.dtype(np.int16) : torch.int16,
np.dtype(np.int32) : torch.int32,
np.dtype(np.int64) : torch.int64,
np.dtype(np.float16) : torch.float16,
np.dtype(np.float32) : torch.float32,
np.dtype(np.float64) : torch.float64,
np.dtype(np.complex64) : torch.complex64,
np.dtype(np.complex128): torch.complex128
}
# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
# numpy dtypes like np.float64 are not instances, but rather classes. This leads
# to rather absurd cases like np.float64 != np.dtype("float64") but
# np.dtype(np.float64) == np.dtype("float64") and
# np.dtype(np.dtype("float64")) == np.dtype("float64"). Especially when
# checking against a reference we can't be sure which variant we get, so we
# simply apply the conversion.
def numpy_to_torch_dtype(np_dtype):
try:
return numpy_to_torch_dtype_dict[np_dtype]
except KeyError:
return numpy_to_torch_dtype_dict[np_dtype.type]
return numpy_to_torch_dtype_dict[np.dtype(np_dtype)]
def has_corresponding_torch_dtype(np_dtype):