Revert "[dynamo] Fix np.issubdtype (#116459)"

This reverts commit b5c33ccdb3198a48a354e21a4fdace0ec6d04146.

Reverted https://github.com/pytorch/pytorch/pull/116459 on behalf of https://github.com/zou3519 due to Broke CI, seems to be a landrace ([comment](https://github.com/pytorch/pytorch/pull/116459#issuecomment-1877135999))
This commit is contained in:
PyTorch MergeBot
2024-01-04 14:00:11 +00:00
parent 3a0f6897c5
commit 75dae4f691
3 changed files with 35 additions and 71 deletions

View File

@ -12,7 +12,9 @@ from . import _dtypes_impl
class generic:
name = "generic"
@property
def name(self):
return self.__class__.__name__
def __new__(cls, value):
# NumPy scalars are modelled as 0-D arrays
@ -35,44 +37,33 @@ class generic:
class number(generic):
name = "number"
pass
class integer(number):
name = "integer"
pass
class inexact(number):
name = "inexact"
pass
class signedinteger(integer):
name = "signedinteger"
pass
class unsignedinteger(integer):
name = "unsignedinteger"
pass
class floating(inexact):
name = "floating"
pass
class complexfloating(inexact):
name = "complexfloating"
pass
_abstract_dtypes = [
"generic",
"number",
"integer",
"signedinteger",
"unsignedinteger",
"inexact",
"floating",
"complexfloating",
]
# ##### concrete types
# signed integers
@ -408,17 +399,6 @@ def issubclass_(arg, klass):
def issubdtype(arg1, arg2):
# cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
# We also accept strings even if NumPy doesn't as dtypes are serialized as their
# string representation in dynamo's graph
def str_to_abstract(t):
if isinstance(t, str) and t in _abstract_dtypes:
return globals()[t]
return t
arg1 = str_to_abstract(arg1)
arg2 = str_to_abstract(arg2)
if not issubclass_(arg1, generic):
arg1 = dtype(arg1).type
if not issubclass_(arg2, generic):
@ -426,7 +406,17 @@ def issubdtype(arg1, arg2):
return issubclass(arg1, arg2)
__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"]
__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype"]
__all__ += list(_names.keys()) # noqa: PLE0605
__all__ += list(_name_aliases.keys()) # noqa: PLE0605
__all__ += _abstract_dtypes # noqa: PLE0605
__all__ += [ # noqa: PLE0605
"sctypes",
"generic",
"number",
"integer",
"signedinteger",
"unsignedinteger",
"inexact",
"floating",
"complexfloating",
]