Revert "[dynamo] Fix np.issubdtype (#116459)"

This reverts commit b5c33ccdb3198a48a354e21a4fdace0ec6d04146.

Reverted https://github.com/pytorch/pytorch/pull/116459 on behalf of https://github.com/zou3519 due to Broke CI, seems to be a landrace ([comment](https://github.com/pytorch/pytorch/pull/116459#issuecomment-1877135999))
PyTorch MergeBot
2024-01-04 14:00:11 +00:00
parent 3a0f6897c5
commit 75dae4f691
3 changed files with 35 additions and 71 deletions

@@ -1875,17 +1875,6 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 2)

-    def test_numpy_subdtype(self):
-        def fn(x, n):
-            return np.issubdtype(type(n), np.integer) + x
-
-        args = [torch.randn(10), 4096]
-        correct = fn(*args)
-        cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
-        self.assertEqual(opt_fn(*args), correct)
-        self.assertEqual(cnts.frame_count, 1)
-
     def test_numpy_take_along_axis(self):
         def fn(x, i, a):
             return np.take_along_axis(x, i, a)
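For reference, the removed test exercises np.issubdtype returning a plain Python bool, which eager mode simply adds to the tensor; the compiled run is expected to match it with a single frame. A minimal eager-only sketch of that behavior (plain NumPy/PyTorch, not part of this diff):

import numpy as np
import torch

def fn(x, n):
    # np.issubdtype(type(4096), np.integer) is True, so this adds 1.0 elementwise
    return np.issubdtype(type(n), np.integer) + x

x = torch.zeros(3)
print(fn(x, 4096))  # tensor([1., 1., 1.])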

@@ -16,7 +16,6 @@ from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
 from ..utils import (
     check_constant_args,
-    check_unspec_python_args,
     identity,
     is_tensor_base_attr_getter,
     proxy_args_kwargs,
@@ -834,18 +833,10 @@ class NumpyVariable(VariableTracker):
     Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.
     """

-    constant_fold_functions = (tnp.issubdtype,)
-
     def __init__(self, value, **kwargs):
         super().__init__(**kwargs)
         self.value = value

-    @classmethod
-    def can_constant_fold_through(cls, fn):
-        mod = fn.__module__.split(".")
-        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
-        return fn in cls.constant_fold_functions
-
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
@@ -877,21 +868,8 @@ class NumpyVariable(VariableTracker):
             msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}"
             unimplemented(msg)

-        constant_args = check_constant_args(args, kwargs)
-        unspec_python_args = check_unspec_python_args(args, kwargs)
-
-        if self.can_constant_fold_through(func) and (
-            constant_args or unspec_python_args
-        ):
-            # constant fold
-            return variables.ConstantVariable.create(
-                self.as_python_constant()(
-                    *[x.as_python_constant() for x in args],
-                    **{k: v.as_python_constant() for k, v in kwargs.items()},
-                ),
-            )
-
-        # TODO Add all the functions that go from constants to constants to can_constant_fold_through
         # TODO(larryliu0820): currently assuming all numpy.* functions are returning a ndarray that can be
         # wrapped by NumpyNdarrayVariable which is wrong!
         proxy = tx.output.create_proxy(
             "call_function",
             numpy_to_tensor_wrapper(func),
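For context, the constant-fold path removed above evaluates whitelisted torch._numpy functions at trace time instead of emitting a graph node. A rough standalone sketch of that decision (a hypothetical helper, not Dynamo's actual VariableTracker machinery):

import torch._numpy as tnp

CONSTANT_FOLD_FUNCTIONS = (tnp.issubdtype,)

def maybe_constant_fold(fn, args, kwargs):
    # Only fold whitelisted torch._numpy functions whose arguments are plain constants.
    is_constant = all(
        isinstance(a, (bool, int, float, str, type)) for a in [*args, *kwargs.values()]
    )
    if fn in CONSTANT_FOLD_FUNCTIONS and is_constant:
        return fn(*args, **kwargs)  # evaluated eagerly; Dynamo would wrap this in a ConstantVariable
    return None  # otherwise fall through to normal proxy tracing

print(maybe_constant_fold(tnp.issubdtype, (tnp.int64, tnp.integer), {}))  # True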
@@ -915,11 +893,18 @@ class NumpyVariable(VariableTracker):
         return self.value

     def as_proxy(self):
+        # this handles numpy dtype attribute such as np.float32. TODO(larryliu0820): we should split NumpyVariable
+        # into NumpyVariable for instances/objects and NumpyVariable for types.
         if config.trace_numpy and isinstance(self.value, type):
-            # This handles numpy dtype attributes such as np.float32
-            # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
-            # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
-            return self.value.__name__
+            # retrieve attribute str. E.g., "float32" if given np.float32
+            attr = self.value.__name__
+            # get tnp equivalent
+            tnp_dtype = tnp.dtype(attr)
+            # returning a string here because we are assuming all `dtype` kwargs for numpy
+            # functions can take an equivalent string and the behavior of the function would
+            # be the same as taking a numpy dtype.
+            return tnp_dtype.name

         return super().as_proxy()
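The two as_proxy variants differ in how dtype classes are serialized into the graph: the code restored by this revert maps the class name through tnp.dtype first, which generally only covers concrete dtypes, while the change being reverted passed the class name straight through so abstract types also survive. In plain NumPy terms (an illustration, not code from this diff):

import numpy as np

np.dtype("float32").name             # 'float32' -- concrete dtypes round-trip through a string
# np.dtype("integer")                # TypeError -- abstract classes such as np.integer have no concrete dtype
np.issubdtype(np.int64, np.integer)  # True -- the subtype check itself works on the class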

@@ -12,7 +12,9 @@ from . import _dtypes_impl
 class generic:
-    name = "generic"
+    @property
+    def name(self):
+        return self.__class__.__name__

     def __new__(cls, value):
         # NumPy scalars are modelled as 0-D arrays
@@ -35,44 +37,33 @@ class generic:
 class number(generic):
-    name = "number"
+    pass


 class integer(number):
-    name = "integer"
+    pass


 class inexact(number):
-    name = "inexact"
+    pass


 class signedinteger(integer):
-    name = "signedinteger"
+    pass


 class unsignedinteger(integer):
-    name = "unsignedinteger"
+    pass


 class floating(inexact):
-    name = "floating"
+    pass


 class complexfloating(inexact):
-    name = "complexfloating"
+    pass


-_abstract_dtypes = [
-    "generic",
-    "number",
-    "integer",
-    "signedinteger",
-    "unsignedinteger",
-    "inexact",
-    "floating",
-    "complexfloating",
-]

 # ##### concrete types

 # signed integers
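These classes mirror NumPy's abstract scalar-type hierarchy, which issubdtype walks via issubclass checks. For example, in NumPy itself:

import numpy as np

assert np.issubdtype(np.int32, np.signedinteger)
assert np.issubdtype(np.signedinteger, np.integer)
assert not np.issubdtype(np.float64, np.integer)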
@@ -408,17 +399,6 @@ def issubclass_(arg, klass):
 def issubdtype(arg1, arg2):
     # cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
-
-    # We also accept strings even if NumPy doesn't as dtypes are serialized as their
-    # string representation in dynamo's graph
-    def str_to_abstract(t):
-        if isinstance(t, str) and t in _abstract_dtypes:
-            return globals()[t]
-        return t
-
-    arg1 = str_to_abstract(arg1)
-    arg2 = str_to_abstract(arg2)
-
     if not issubclass_(arg1, generic):
         arg1 = dtype(arg1).type
     if not issubclass_(arg2, generic):
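The removed str_to_abstract helper let torch._numpy.issubdtype accept the string names that Dynamo serializes into its graph, in addition to the classes NumPy normally expects. A hedged sketch of the intended difference (assuming torch._numpy is importable as tnp; the commented call fails once this revert lands):

import torch._numpy as tnp

tnp.issubdtype(tnp.int64, tnp.integer)    # True with or without the reverted change
# tnp.issubdtype("integer", tnp.integer)  # accepted only while str_to_abstract is present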
@@ -426,7 +406,17 @@ def issubdtype(arg1, arg2):
     return issubclass(arg1, arg2)


-__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"]
+__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype"]
 __all__ += list(_names.keys())  # noqa: PLE0605
 __all__ += list(_name_aliases.keys())  # noqa: PLE0605
-__all__ += _abstract_dtypes  # noqa: PLE0605
+__all__ += [  # noqa: PLE0605
+    "sctypes",
+    "generic",
+    "number",
+    "integer",
+    "signedinteger",
+    "unsignedinteger",
+    "inexact",
+    "floating",
+    "complexfloating",
+]