[dynamo] Clean up assert in dynamo [1/N] (#165430)

Fixes part of #162852 and #164878; the two issues are related.

* __->__ #165430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165430
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42

Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
This commit is contained in:
can-gaa-hou
2025-10-19 21:00:01 +00:00
committed by PyTorch MergeBot
parent 633a3b7f67
commit a88587348b
8 changed files with 87 additions and 28 deletions

View File

@ -2800,5 +2800,15 @@
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0278": [
{
"Gb_type": "Unsupported dict type for fromkeys()",
"Context": "{user_cls.__name__}.fromkeys(): {args} {kwargs}",
"Explanation": "Failed to call {user_cls.__name__}.fromkeys() because {user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict",
"Hints": [
"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict."
]
}
]
}

View File

@ -636,6 +636,11 @@ class VariableTracker(metaclass=VariableTrackerMeta):
assert source is not None
def raise_type_error_exc(tx: "InstructionTranslator", msg_str: str) -> None:
    """Raise a TypeError inside the traced program via Dynamo's observed-exception machinery.

    Wraps `msg_str` in a ConstantVariable and delegates to `raise_observed_exception`,
    so the TypeError surfaces to the user's code being traced rather than crashing
    the tracer itself.

    Args:
        tx: The current InstructionTranslator driving symbolic execution.
        msg_str: Plain-text error message for the TypeError.
    """
    # The message must be a VariableTracker (ConstantVariable) because
    # raise_observed_exception forwards it as the exception's args within the trace.
    msg = variables.ConstantVariable.create(msg_str)
    raise_observed_exception(TypeError, tx, args=[msg])
def typestr(*objs):
if len(objs) == 1:
(obj,) = objs

View File

@ -81,7 +81,12 @@ from ..utils import (
str_methods,
tensortype_to_dtype,
)
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
from .base import (
AsPythonConstantNotImplementedError,
raise_type_error_exc,
ValueMutationNew,
VariableTracker,
)
from .constant import ConstantVariable
from .dicts import (
ConstDictVariable,
@ -1930,20 +1935,36 @@ class BuiltinVariable(VariableTracker):
def call_custom_dict_fromkeys(
tx: "InstructionTranslator", user_cls, *args, **kwargs
):
assert user_cls in {dict, OrderedDict, defaultdict}
if user_cls not in {dict, OrderedDict, defaultdict}:
unimplemented_v2(
gb_type="Unsupported dict type for fromkeys()",
context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}",
explanation=f"Failed to call {user_cls.__name__}.fromkeys() because "
f"{user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict",
hints=[
f"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict.",
],
)
if kwargs:
# Only `OrderedDict.fromkeys` accepts `value` passed by keyword
assert user_cls is OrderedDict
assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
if (
user_cls is not OrderedDict
or len(args) != 1
or len(kwargs) != 1
or "value" not in kwargs
):
raise_type_error_exc(
tx, f"{user_cls.__name__}.fromkeys() takes no keyword arguments"
)
args = (*args, kwargs.pop("value"))
if len(args) == 0:
msg = ConstantVariable.create(
"fromkeys expected at least 1 arguments, got 0"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_type_error_exc(tx, "fromkeys expected at least 1 arguments, got 0")
if len(args) == 1:
args = (*args, ConstantVariable.create(None))
assert len(args) == 2
if len(args) != 2:
raise_type_error_exc(
tx, f"fromkeys expected at most 2 arguments, got {len(args)}"
)
arg, value = args
DictVariableType = (
ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
@ -2039,7 +2060,11 @@ class BuiltinVariable(VariableTracker):
def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
if kwargs:
assert len(kwargs) == 1 and "strict" in kwargs
if not (len(kwargs) == 1 and "strict" in kwargs):
raise_type_error_exc(
tx,
f"zip() should only have 'strict' keyword argument, but ({len(kwargs)} given)",
)
strict = kwargs.pop("strict", False)
args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args]
return variables.ZipVariable(

View File

@ -17,7 +17,7 @@ from torch._dynamo.source import AttrSource, GetItemSource
from .. import graph_break_hints, variables
from ..exc import raise_observed_exception, unimplemented_v2
from ..utils import cmp_name_to_op_mapping, common_constant_types, istype, np
from .base import VariableTracker
from .base import raise_type_error_exc, VariableTracker
if TYPE_CHECKING:
@ -149,7 +149,8 @@ its type to `common_constant_types`.
tx, [self, *args], kwargs
)
elif name == "join" and istype(self.value, str):
assert len(args) == 1 and len(kwargs) == 0
if not (len(args) == 1 and len(kwargs) == 0):
raise_type_error_exc(tx, "str.join() takes exactly one argument")
arg_unpacked = args[0].force_unpack_var_sequence(tx)
try:
arg_const = [x.as_python_constant() for x in arg_unpacked]

View File

@ -44,7 +44,7 @@ from ..utils import (
raise_args_mismatch,
specialize_symnode,
)
from .base import ValueMutationNew, VariableTracker
from .base import raise_type_error_exc, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
@ -508,7 +508,11 @@ class ConstDictVariable(VariableTracker):
raise_unhashable(args[0])
self.install_dict_keys_match_guard()
assert not kwargs and len(args) == 2
if kwargs or len(args) != 2:
raise_type_error_exc(
tx,
f"dict.__setitem__ takes exactly two arguments ({len(args)} given)",
)
tx.output.side_effects.mutation(self)
self.items[Hashable(args[0])] = args[1]
return ConstantVariable.create(None)

View File

@ -47,7 +47,7 @@ from ..utils import (
range_iterator,
set_example_value,
)
from .base import ValueMutationNew, VariableTracker
from .base import raise_type_error_exc, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable
@ -147,13 +147,11 @@ class BaseListVariable(VariableTracker):
if name == "__getitem__":
from .tensor import TensorVariable
if len(args) != 1:
msg = ConstantVariable.create(
f"{name} takes exactly one argument ({len(args)} given)"
if kwargs or len(args) != 1:
raise_type_error_exc(
tx, f"{name} takes exactly one argument ({len(args)} given)"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert not kwargs and len(args) == 1
if isinstance(args[0], TensorVariable):
value = get_fake_value(args[0].as_proxy().node, tx)
if value.constant is not None and value.constant.numel() == 1:
@ -1115,11 +1113,15 @@ class SizeVariable(TupleVariable):
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__getitem__":
assert not kwargs and len(args) == 1
if kwargs or len(args) != 1:
raise_type_error_exc(
tx, f"{name} takes exactly one argument ({len(args)} given)"
)
out = self.get_item_dyn(tx, args[0])
return out
elif name == "numel":
assert not args and not kwargs
if args or kwargs:
raise_type_error_exc(tx, f"{name} takes no arguments")
return self.numel(tx)
return super().call_method(tx, name, args, kwargs)

View File

@ -60,7 +60,7 @@ from ..utils import (
proxy_args_kwargs,
tuple_methods,
)
from .base import VariableTracker
from .base import raise_type_error_exc, VariableTracker
from .constant import ConstantVariable
from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
@ -1218,7 +1218,10 @@ class MethodWrapperVariable(VariableTracker):
if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
args[0], variables.TensorVariable
):
assert len(args) == 1 and len(kwargs) == 0
if not (len(args) == 1 and len(kwargs) == 0):
raise_type_error_exc(
tx, "tensor attribute getter takes exactly one argument"
)
return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

View File

@ -69,7 +69,7 @@ from ..utils import (
proxy_args_kwargs,
unwrap_if_wrapper,
)
from .base import typestr, VariableTracker
from .base import raise_type_error_exc, typestr, VariableTracker
from .ctx_manager import (
AutocastModeVariable,
ProfilerContextVariable,
@ -1179,7 +1179,11 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
def handle_push_torch_function(
self, tx: "InstructionTranslator", *args, **kwargs
):
assert len(args) == 1 and not kwargs
if len(args) != 1 or kwargs:
raise_type_error_exc(
tx,
f"push_torch_function takes exactly one argument ({len(args)} given)",
)
TorchFunctionModeStackVariable.register_mutation(tx)
tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
return ConstantVariable.create(None)
@ -1188,14 +1192,19 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
def handle_len_torch_function(
self, tx: "InstructionTranslator", *args, **kwargs
):
assert not args and not kwargs
if args or kwargs:
raise_type_error_exc(tx, "len_torch_function_stack takes no arguments")
return ConstantVariable.create(
len(tx.symbolic_torch_function_state.mode_stack)
)
@register(torch._C._get_function_stack_at)
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
assert len(args) == 1 and not kwargs
if len(args) != 1 or kwargs:
raise_type_error_exc(
tx,
f"get_function_stack_at takes exactly one argument ({len(args)} given)",
)
ind = args[0].as_python_constant()
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
return tx.symbolic_torch_function_state.mode_stack[ind]