mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Clean up assert in dynamo [1/N] (#165430)
Fixes some part of #162852 and #164878. These two issues have some relationship though. * __->__ #165430 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165430 Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42 Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
633a3b7f67
commit
a88587348b
@ -2800,5 +2800,15 @@
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0278": [
|
||||
{
|
||||
"Gb_type": "Unsupported dict type for fromkeys()",
|
||||
"Context": "{user_cls.__name__}.fromkeys(): {args} {kwargs}",
|
||||
"Explanation": "Failed to call {user_cls.__name__}.fromkeys() because {user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict",
|
||||
"Hints": [
|
||||
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -636,6 +636,11 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
assert source is not None
|
||||
|
||||
|
||||
def raise_type_error_exc(tx: "InstructionTranslator", msg_str: str) -> None:
|
||||
msg = variables.ConstantVariable.create(msg_str)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
|
||||
def typestr(*objs):
|
||||
if len(objs) == 1:
|
||||
(obj,) = objs
|
||||
|
@ -81,7 +81,12 @@ from ..utils import (
|
||||
str_methods,
|
||||
tensortype_to_dtype,
|
||||
)
|
||||
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
|
||||
from .base import (
|
||||
AsPythonConstantNotImplementedError,
|
||||
raise_type_error_exc,
|
||||
ValueMutationNew,
|
||||
VariableTracker,
|
||||
)
|
||||
from .constant import ConstantVariable
|
||||
from .dicts import (
|
||||
ConstDictVariable,
|
||||
@ -1930,20 +1935,36 @@ class BuiltinVariable(VariableTracker):
|
||||
def call_custom_dict_fromkeys(
|
||||
tx: "InstructionTranslator", user_cls, *args, **kwargs
|
||||
):
|
||||
assert user_cls in {dict, OrderedDict, defaultdict}
|
||||
if user_cls not in {dict, OrderedDict, defaultdict}:
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported dict type for fromkeys()",
|
||||
context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}",
|
||||
explanation=f"Failed to call {user_cls.__name__}.fromkeys() because "
|
||||
f"{user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict",
|
||||
hints=[
|
||||
f"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict.",
|
||||
],
|
||||
)
|
||||
if kwargs:
|
||||
# Only `OrderedDict.fromkeys` accepts `value` passed by keyword
|
||||
assert user_cls is OrderedDict
|
||||
assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
|
||||
if (
|
||||
user_cls is not OrderedDict
|
||||
or len(args) != 1
|
||||
or len(kwargs) != 1
|
||||
or "value" not in kwargs
|
||||
):
|
||||
raise_type_error_exc(
|
||||
tx, f"{user_cls.__name__}.fromkeys() takes no keyword arguments"
|
||||
)
|
||||
args = (*args, kwargs.pop("value"))
|
||||
if len(args) == 0:
|
||||
msg = ConstantVariable.create(
|
||||
"fromkeys expected at least 1 arguments, got 0"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
raise_type_error_exc(tx, "fromkeys expected at least 1 arguments, got 0")
|
||||
if len(args) == 1:
|
||||
args = (*args, ConstantVariable.create(None))
|
||||
assert len(args) == 2
|
||||
if len(args) != 2:
|
||||
raise_type_error_exc(
|
||||
tx, f"fromkeys expected at most 2 arguments, got {len(args)}"
|
||||
)
|
||||
arg, value = args
|
||||
DictVariableType = (
|
||||
ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
|
||||
@ -2039,7 +2060,11 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
if kwargs:
|
||||
assert len(kwargs) == 1 and "strict" in kwargs
|
||||
if not (len(kwargs) == 1 and "strict" in kwargs):
|
||||
raise_type_error_exc(
|
||||
tx,
|
||||
f"zip() should only have 'strict' keyword argument, but ({len(kwargs)} given)",
|
||||
)
|
||||
strict = kwargs.pop("strict", False)
|
||||
args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args]
|
||||
return variables.ZipVariable(
|
||||
|
@ -17,7 +17,7 @@ from torch._dynamo.source import AttrSource, GetItemSource
|
||||
from .. import graph_break_hints, variables
|
||||
from ..exc import raise_observed_exception, unimplemented_v2
|
||||
from ..utils import cmp_name_to_op_mapping, common_constant_types, istype, np
|
||||
from .base import VariableTracker
|
||||
from .base import raise_type_error_exc, VariableTracker
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -149,7 +149,8 @@ its type to `common_constant_types`.
|
||||
tx, [self, *args], kwargs
|
||||
)
|
||||
elif name == "join" and istype(self.value, str):
|
||||
assert len(args) == 1 and len(kwargs) == 0
|
||||
if not (len(args) == 1 and len(kwargs) == 0):
|
||||
raise_type_error_exc(tx, "str.join() takes exactly one argument")
|
||||
arg_unpacked = args[0].force_unpack_var_sequence(tx)
|
||||
try:
|
||||
arg_const = [x.as_python_constant() for x in arg_unpacked]
|
||||
|
@ -44,7 +44,7 @@ from ..utils import (
|
||||
raise_args_mismatch,
|
||||
specialize_symnode,
|
||||
)
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
from .base import raise_type_error_exc, ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
|
||||
|
||||
@ -508,7 +508,11 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_unhashable(args[0])
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
assert not kwargs and len(args) == 2
|
||||
if kwargs or len(args) != 2:
|
||||
raise_type_error_exc(
|
||||
tx,
|
||||
f"dict.__setitem__ takes exactly two arguments ({len(args)} given)",
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items[Hashable(args[0])] = args[1]
|
||||
return ConstantVariable.create(None)
|
||||
|
@ -47,7 +47,7 @@ from ..utils import (
|
||||
range_iterator,
|
||||
set_example_value,
|
||||
)
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
from .base import raise_type_error_exc, ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .functions import UserFunctionVariable, UserMethodVariable
|
||||
from .iter import IteratorVariable
|
||||
@ -147,13 +147,11 @@ class BaseListVariable(VariableTracker):
|
||||
if name == "__getitem__":
|
||||
from .tensor import TensorVariable
|
||||
|
||||
if len(args) != 1:
|
||||
msg = ConstantVariable.create(
|
||||
f"{name} takes exactly one argument ({len(args)} given)"
|
||||
if kwargs or len(args) != 1:
|
||||
raise_type_error_exc(
|
||||
tx, f"{name} takes exactly one argument ({len(args)} given)"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
assert not kwargs and len(args) == 1
|
||||
if isinstance(args[0], TensorVariable):
|
||||
value = get_fake_value(args[0].as_proxy().node, tx)
|
||||
if value.constant is not None and value.constant.numel() == 1:
|
||||
@ -1115,11 +1113,15 @@ class SizeVariable(TupleVariable):
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if name == "__getitem__":
|
||||
assert not kwargs and len(args) == 1
|
||||
if kwargs or len(args) != 1:
|
||||
raise_type_error_exc(
|
||||
tx, f"{name} takes exactly one argument ({len(args)} given)"
|
||||
)
|
||||
out = self.get_item_dyn(tx, args[0])
|
||||
return out
|
||||
elif name == "numel":
|
||||
assert not args and not kwargs
|
||||
if args or kwargs:
|
||||
raise_type_error_exc(tx, f"{name} takes no arguments")
|
||||
return self.numel(tx)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
@ -60,7 +60,7 @@ from ..utils import (
|
||||
proxy_args_kwargs,
|
||||
tuple_methods,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .base import raise_type_error_exc, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .functions import NestedUserFunctionVariable, UserFunctionVariable
|
||||
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
|
||||
@ -1218,7 +1218,10 @@ class MethodWrapperVariable(VariableTracker):
|
||||
if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
|
||||
args[0], variables.TensorVariable
|
||||
):
|
||||
assert len(args) == 1 and len(kwargs) == 0
|
||||
if not (len(args) == 1 and len(kwargs) == 0):
|
||||
raise_type_error_exc(
|
||||
tx, "tensor attribute getter takes exactly one argument"
|
||||
)
|
||||
|
||||
return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)
|
||||
|
||||
|
@ -69,7 +69,7 @@ from ..utils import (
|
||||
proxy_args_kwargs,
|
||||
unwrap_if_wrapper,
|
||||
)
|
||||
from .base import typestr, VariableTracker
|
||||
from .base import raise_type_error_exc, typestr, VariableTracker
|
||||
from .ctx_manager import (
|
||||
AutocastModeVariable,
|
||||
ProfilerContextVariable,
|
||||
@ -1179,7 +1179,11 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
def handle_push_torch_function(
|
||||
self, tx: "InstructionTranslator", *args, **kwargs
|
||||
):
|
||||
assert len(args) == 1 and not kwargs
|
||||
if len(args) != 1 or kwargs:
|
||||
raise_type_error_exc(
|
||||
tx,
|
||||
f"push_torch_function takes exactly one argument ({len(args)} given)",
|
||||
)
|
||||
TorchFunctionModeStackVariable.register_mutation(tx)
|
||||
tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
|
||||
return ConstantVariable.create(None)
|
||||
@ -1188,14 +1192,19 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
def handle_len_torch_function(
|
||||
self, tx: "InstructionTranslator", *args, **kwargs
|
||||
):
|
||||
assert not args and not kwargs
|
||||
if args or kwargs:
|
||||
raise_type_error_exc(tx, "len_torch_function_stack takes no arguments")
|
||||
return ConstantVariable.create(
|
||||
len(tx.symbolic_torch_function_state.mode_stack)
|
||||
)
|
||||
|
||||
@register(torch._C._get_function_stack_at)
|
||||
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
assert len(args) == 1 and not kwargs
|
||||
if len(args) != 1 or kwargs:
|
||||
raise_type_error_exc(
|
||||
tx,
|
||||
f"get_function_stack_at takes exactly one argument ({len(args)} given)",
|
||||
)
|
||||
ind = args[0].as_python_constant()
|
||||
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
|
||||
return tx.symbolic_torch_function_state.mode_stack[ind]
|
||||
|
Reference in New Issue
Block a user