mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[dynamo] Fix np.issubdtype (#116459)"
This reverts commit b5c33ccdb3198a48a354e21a4fdace0ec6d04146. Reverted https://github.com/pytorch/pytorch/pull/116459 on behalf of https://github.com/zou3519 due to Broke CI, seems to be a landrace ([comment](https://github.com/pytorch/pytorch/pull/116459#issuecomment-1877135999))
This commit is contained in:
@ -1875,17 +1875,6 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 2)
|
||||
|
||||
def test_numpy_subdtype(self):
|
||||
def fn(x, n):
|
||||
return np.issubdtype(type(n), np.integer) + x
|
||||
|
||||
args = [torch.randn(10), 4096]
|
||||
correct = fn(*args)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
self.assertEqual(opt_fn(*args), correct)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_numpy_take_along_axis(self):
|
||||
def fn(x, i, a):
|
||||
return np.take_along_axis(x, i, a)
|
||||
|
@ -16,7 +16,6 @@ from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
|
||||
from ..utils import (
|
||||
check_constant_args,
|
||||
check_unspec_python_args,
|
||||
identity,
|
||||
is_tensor_base_attr_getter,
|
||||
proxy_args_kwargs,
|
||||
@ -834,18 +833,10 @@ class NumpyVariable(VariableTracker):
|
||||
Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.
|
||||
"""
|
||||
|
||||
constant_fold_functions = (tnp.issubdtype,)
|
||||
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.value = value
|
||||
|
||||
@classmethod
|
||||
def can_constant_fold_through(cls, fn):
|
||||
mod = fn.__module__.split(".")
|
||||
assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
|
||||
return fn in cls.constant_fold_functions
|
||||
|
||||
def call_function(
|
||||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||||
) -> "VariableTracker":
|
||||
@ -877,21 +868,8 @@ class NumpyVariable(VariableTracker):
|
||||
msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}"
|
||||
unimplemented(msg)
|
||||
|
||||
constant_args = check_constant_args(args, kwargs)
|
||||
unspec_python_args = check_unspec_python_args(args, kwargs)
|
||||
|
||||
if self.can_constant_fold_through(func) and (
|
||||
constant_args or unspec_python_args
|
||||
):
|
||||
# constant fold
|
||||
return variables.ConstantVariable.create(
|
||||
self.as_python_constant()(
|
||||
*[x.as_python_constant() for x in args],
|
||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||
),
|
||||
)
|
||||
|
||||
# TODO Add all the functions that go from constants to constants to can_constant_fold_through
|
||||
# TODO(larryliu0820): currently assuming all numpy.* functions are returning a ndarray that can be
|
||||
# wrapped by NumpyNdarrayVariable which is wrong!
|
||||
proxy = tx.output.create_proxy(
|
||||
"call_function",
|
||||
numpy_to_tensor_wrapper(func),
|
||||
@ -915,11 +893,18 @@ class NumpyVariable(VariableTracker):
|
||||
return self.value
|
||||
|
||||
def as_proxy(self):
|
||||
# this handles numpy dtype attribute such as np.float32. TODO(larryliu0820): we should split NumpyVariable
|
||||
# into NumpyVariable for instances/objects and NumpyVariable for types.
|
||||
if config.trace_numpy and isinstance(self.value, type):
|
||||
# This handles numpy dtype attributes such as np.float32
|
||||
# We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
|
||||
# In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
|
||||
return self.value.__name__
|
||||
# retrieve attribute str. E.g., "float32" if given np.float32
|
||||
|
||||
attr = self.value.__name__
|
||||
# get tnp equivalent
|
||||
tnp_dtype = tnp.dtype(attr)
|
||||
# returning a string here because we are assuming all `dtype` kwargs for numpy
|
||||
# functions can take an equivalent string and the behavior of the function would
|
||||
# be the same as taking a numpy dtype.
|
||||
return tnp_dtype.name
|
||||
|
||||
return super().as_proxy()
|
||||
|
||||
|
@ -12,7 +12,9 @@ from . import _dtypes_impl
|
||||
|
||||
|
||||
class generic:
|
||||
name = "generic"
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __new__(cls, value):
|
||||
# NumPy scalars are modelled as 0-D arrays
|
||||
@ -35,44 +37,33 @@ class generic:
|
||||
|
||||
|
||||
class number(generic):
|
||||
name = "number"
|
||||
pass
|
||||
|
||||
|
||||
class integer(number):
|
||||
name = "integer"
|
||||
pass
|
||||
|
||||
|
||||
class inexact(number):
|
||||
name = "inexact"
|
||||
pass
|
||||
|
||||
|
||||
class signedinteger(integer):
|
||||
name = "signedinteger"
|
||||
pass
|
||||
|
||||
|
||||
class unsignedinteger(integer):
|
||||
name = "unsignedinteger"
|
||||
pass
|
||||
|
||||
|
||||
class floating(inexact):
|
||||
name = "floating"
|
||||
pass
|
||||
|
||||
|
||||
class complexfloating(inexact):
|
||||
name = "complexfloating"
|
||||
pass
|
||||
|
||||
|
||||
_abstract_dtypes = [
|
||||
"generic",
|
||||
"number",
|
||||
"integer",
|
||||
"signedinteger",
|
||||
"unsignedinteger",
|
||||
"inexact",
|
||||
"floating",
|
||||
"complexfloating",
|
||||
]
|
||||
|
||||
# ##### concrete types
|
||||
|
||||
# signed integers
|
||||
@ -408,17 +399,6 @@ def issubclass_(arg, klass):
|
||||
|
||||
def issubdtype(arg1, arg2):
|
||||
# cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
|
||||
|
||||
# We also accept strings even if NumPy doesn't as dtypes are serialized as their
|
||||
# string representation in dynamo's graph
|
||||
def str_to_abstract(t):
|
||||
if isinstance(t, str) and t in _abstract_dtypes:
|
||||
return globals()[t]
|
||||
return t
|
||||
|
||||
arg1 = str_to_abstract(arg1)
|
||||
arg2 = str_to_abstract(arg2)
|
||||
|
||||
if not issubclass_(arg1, generic):
|
||||
arg1 = dtype(arg1).type
|
||||
if not issubclass_(arg2, generic):
|
||||
@ -426,7 +406,17 @@ def issubdtype(arg1, arg2):
|
||||
return issubclass(arg1, arg2)
|
||||
|
||||
|
||||
__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"]
|
||||
__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype"]
|
||||
__all__ += list(_names.keys()) # noqa: PLE0605
|
||||
__all__ += list(_name_aliases.keys()) # noqa: PLE0605
|
||||
__all__ += _abstract_dtypes # noqa: PLE0605
|
||||
__all__ += [ # noqa: PLE0605
|
||||
"sctypes",
|
||||
"generic",
|
||||
"number",
|
||||
"integer",
|
||||
"signedinteger",
|
||||
"unsignedinteger",
|
||||
"inexact",
|
||||
"floating",
|
||||
"complexfloating",
|
||||
]
|
||||
|
Reference in New Issue
Block a user