[list] Raise exception in invalid list method call (#156148)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156148
Approved by: https://github.com/zou3519
ghstack dependencies: #153969
This commit is contained in:
Guilherme Leobas
2025-07-05 15:54:44 -03:00
committed by PyTorch MergeBot
parent 034e996d37
commit e49acfc5c5
24 changed files with 438 additions and 56 deletions

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py
index dbc5ef4f9f2..2b9f3b9311f 100644
index dbc5ef4f9f2..239b75f74cc 100644
--- a/test/dynamo/cpython/3_13/list_tests.py
+++ b/test/dynamo/cpython/3_13/list_tests.py
@@ -1,3 +1,53 @@
@ -65,3 +65,14 @@ index dbc5ef4f9f2..2b9f3b9311f 100644
from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit
@@ -119,10 +169,6 @@ class CommonTest(seq_tests.CommonTest):
a[-1] = 9
self.assertEqual(a, self.type2test([5,6,7,8,9]))
- msg = "list indices must be integers or slices"
- with self.assertRaisesRegex(TypeError, msg):
- a['a'] = "python"
-
def test_delitem(self):
a = self.type2test([0, 1])
del a[1]

View File

@ -169,10 +169,6 @@ class CommonTest(seq_tests.CommonTest):
a[-1] = 9
self.assertEqual(a, self.type2test([5,6,7,8,9]))
msg = "list indices must be integers or slices"
with self.assertRaisesRegex(TypeError, msg):
a['a'] = "python"
def test_delitem(self):
a = self.type2test([0, 1])
del a[1]

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py
index 23ef902aa0b..30e69ff75bd 100644
index 23ef902aa0b..6e4c6d99d16 100644
--- a/test/dynamo/cpython/3_13/test_list.py
+++ b/test/dynamo/cpython/3_13/test_list.py
@@ -1,6 +1,57 @@
@ -61,7 +61,14 @@ index 23ef902aa0b..30e69ff75bd 100644
from test.support import cpython_only
from test.support.script_helper import assert_python_ok
import pickle
@@ -324,6 +375,7 @@ class ListTest(list_tests.CommonTest):
@@ -35,8 +86,6 @@ class ListTest(list_tests.CommonTest):
# Note: This test is expected to SEGV under Cygwin 1.3.12 or
# earlier due to a newlib bug. See the following mailing list
# thread for the details:
self.assertRaises(MemoryError, list, range(sys.maxsize // 2))
# This code used to segfault in Py2.4a3
@@ -324,6 +373,7 @@ class ListTest(list_tests.CommonTest):
a.append(4)
self.assertEqual(list(it), [])
@ -69,7 +76,7 @@ index 23ef902aa0b..30e69ff75bd 100644
def test_deopt_from_append_list(self):
# gh-132011: it used to crash, because
# of `CALL_LIST_APPEND` specialization failure.
@@ -345,4 +397,4 @@ class ListTest(list_tests.CommonTest):
@@ -345,4 +395,4 @@ class ListTest(list_tests.CommonTest):
self.assertEqual(rc, 0)
if __name__ == "__main__":

305
test/dynamo/test_list.py Normal file
View File

@ -0,0 +1,305 @@
# Owner(s): ["module: dynamo"]
# TODO: move set tests from test_functions.py/test_misc.py to this file
import unittest
import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import make_dynamo_test
class TupleTests(torch._dynamo.test_case.TestCase):
# Tuple methods
# + count
# + index
# BinOps:
# +, <, >, <=, >=, ==, !=
# Dunder methods:
# + __getitem__
# + __contains__
# + __delitem__
thetype = tuple
def setUp(self):
self.old = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
super().setUp()
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self.old
return super().tearDown()
def assertEqual(self, a, b):
return self.assertTrue(a == b, f"{a} != {b}")
def assertNotEqual(self, x, y, msg=None, *, atol=None, rtol=None, **kwargs):
return self.assertTrue(x != y, f"{x} == {y}")
@make_dynamo_test
def test_count(self):
p = self.thetype("abcab")
self.assertEqual(p.count("a"), 2)
self.assertEqual(p.count("ab"), 0)
# Wrong number of arguments
self.assertRaises(TypeError, p.count)
self.assertRaises(TypeError, p.count, 2, 3)
@make_dynamo_test
def test_index(self):
p = self.thetype("abc")
self.assertEqual(p.index("a"), 0)
self.assertRaises(ValueError, p.index, "e")
# Wrong number of arguments
self.assertRaises(TypeError, p.index)
@unittest.expectedFailure
@make_dynamo_test
def test_binop_add(self):
p, q = map(self.thetype, ["abc", "bcd"])
self.assertIsInstance(p + q, self.thetype)
self.assertEqual(p + q, self.thetype("abcbcd"))
self.assertEqual(p.__add__(q), self.thetype("abcbcd"))
# Wrong number of arguments
self.assertRaises(TypeError, p.__add__)
@make_dynamo_test
def test_cmp_eq(self):
p, q, r = map(self.thetype, ["ab", "abc", "ab"])
self.assertTrue(p == p)
self.assertTrue(p == r)
self.assertEqual(p, p)
self.assertEqual(p, r)
self.assertNotEqual(p, q)
self.assertTrue(p.__eq__(r))
# Wrong number of arguments
self.assertRaises(TypeError, p.__eq__)
@make_dynamo_test
def test_cmp_ne(self):
p, q = map(self.thetype, ["ab", "abc"])
self.assertTrue(p != q)
self.assertNotEqual(p, q)
self.assertTrue(p.__ne__(q))
# Wrong number of arguments
self.assertRaises(TypeError, p.__ne__)
@make_dynamo_test
def test_cmp_less_than(self):
p, q = map(self.thetype, ["ab", "abc"])
self.assertTrue(p < q)
self.assertTrue(p.__lt__(q))
self.assertFalse(q < p)
# Wrong number of arguments
self.assertRaises(TypeError, p.__lt__)
@make_dynamo_test
def test_cmp_greater_than(self):
p, q = map(self.thetype, ["ab", "abc"])
self.assertTrue(q > p)
self.assertTrue(q.__gt__(p))
self.assertFalse(p > q)
# Wrong number of arguments
self.assertRaises(TypeError, p.__gt__)
@make_dynamo_test
def test_cmp_less_than_or_equal(self):
p, q = map(self.thetype, ["ab", "abc"])
self.assertTrue(p <= q)
self.assertTrue(p.__le__(q))
self.assertFalse(q <= p)
# Wrong number of arguments
self.assertRaises(TypeError, p.__le__)
@make_dynamo_test
def test_cmp_greater_than_or_equal(self):
p, q = map(self.thetype, ["ab", "abc"])
self.assertTrue(q >= p)
self.assertTrue(q.__ge__(p))
self.assertFalse(p >= q)
# Wrong number of arguments
self.assertRaises(TypeError, p.__ge__)
@make_dynamo_test
def test___getitem__(self):
p = self.thetype("abc")
self.assertEqual(p.__getitem__(2), "c")
self.assertRaises(IndexError, p.__getitem__, 10)
# Wrong number of arguments
self.assertRaises(TypeError, p.__getitem__)
self.assertRaises(TypeError, p.__getitem__, 1, 2)
@make_dynamo_test
def test___contains__(self):
p = self.thetype("abc")
self.assertTrue(p.__contains__("a"))
self.assertIsInstance(p.__contains__("c"), bool)
# Wrong number of arguments
self.assertRaises(TypeError, p.__contains__)
self.assertRaises(TypeError, p.__contains__, 1, 2)
@unittest.expectedFailure
@make_dynamo_test
def test___delitem__(self):
p = self.thetype("abc")
self.assertIsNone(p.__delitem__(1))
self.assertEqual(p, self.thetype("ac"))
# Wrong number of arguments
self.assertRaises(TypeError, p.__delitem__)
self.assertRaises(TypeError, p.__delitem__, 1, 2)
class ListTests(TupleTests):
# List methods
# + append
# + copy
# + clear
# + extend
# + insert
# + pop
# + remove
# + reverse
# + sort
# BinOps:
# +, <, >, <=, >=, ==, !=
# Dunder methods:
# + __setitem__
# + __getitem__
# + __contains__
# + __delitem__
thetype = list
@make_dynamo_test
def test_append(self):
p = self.thetype("abc")
self.assertIsNone(p.append("d"))
self.assertEqual(p, ["a", "b", "c", "d"])
# Wrong number of arguments
self.assertRaises(TypeError, p.append)
self.assertRaises(TypeError, p.append, 2, 3)
@make_dynamo_test
def test_copy(self):
p = self.thetype("abc")
self.assertEqual(p.copy(), p)
# Wrong number of arguments
self.assertRaises(TypeError, p.copy, 1)
@make_dynamo_test
def test_clear(self):
p = self.thetype("abc")
self.assertIsNone(p.clear())
self.assertEqual(p, [])
self.assertEqual(len(p), 0)
# Wrong number of arguments
self.assertRaises(TypeError, p.clear, 1)
@make_dynamo_test
def test_extend(self):
p, q = map(self.thetype, ["ab", "cd"])
self.assertIsNone(p.extend(q))
self.assertEqual(p, self.thetype("abcd"))
# extend needs an iterable
self.assertRaises(TypeError, p.extend, 1)
# Wrong number of arguments
self.assertRaises(TypeError, p.extend)
self.assertRaises(TypeError, p.extend, 2, 3)
@make_dynamo_test
def test_insert(self):
p = self.thetype("abc")
self.assertIsNone(p.insert(1, "ef"))
self.assertEqual(p, ["a", "ef", "b", "c"])
# Wrong number of arguments
self.assertRaises(TypeError, p.insert)
self.assertRaises(TypeError, p.insert, 1)
self.assertRaises(TypeError, p.insert, 1, 2, 3)
@make_dynamo_test
def test_pop(self):
p = self.thetype("abcd")
self.assertEqual(p.pop(), "d")
self.assertEqual(p.pop(1), "b")
self.assertRaises(IndexError, p.pop, 10)
# Wrong number of arguments
self.assertRaises(TypeError, p.pop, 2, 3)
@unittest.expectedFailure
@make_dynamo_test
def test_remove(self):
p = self.thetype("abad")
self.assertIsNone(p.remove("a"))
self.assertEqual(p, ["b", "a", "d"])
self.assertRaises(ValueError, p.remove, "x")
# Wrong number of arguments
self.assertRaises(TypeError, p.remove)
self.assertRaises(TypeError, p.remove, 2, 3)
@make_dynamo_test
def test_reverse(self):
p = self.thetype("abcd")
self.assertIsNone(p.reverse())
self.assertEqual(p, self.thetype("dcba"))
# Wrong number of arguments
self.assertRaises(TypeError, p.reverse, 1)
@make_dynamo_test
def test_sort(self):
p = self.thetype("dbca")
self.assertIsNone(p.sort())
self.assertEqual(p, self.thetype("abcd"))
@unittest.expectedFailure
@make_dynamo_test
def test_binop_iadd(self):
p, q = map(self.thetype, ["abc", "bcd"])
r = p.__iadd__(q)
self.assertIsInstance(r, self.thetype)
self.assertEqual(r, self.thetype("abcbcd"))
self.assertEqual(p, self.thetype("abcbcd"))
# Wrong number of arguments
self.assertRaises(TypeError, p.__iadd__)
@make_dynamo_test
def test___setitem__(self):
p = self.thetype("abc")
self.assertIsNone(p.__setitem__(2, "a"))
self.assertEqual(p, self.thetype("aba"))
p[0:] = []
self.assertEqual(p, [])
# Wrong number of arguments
self.assertRaises(TypeError, p.__setitem__)
self.assertRaises(TypeError, p.__setitem__, 1)
self.assertRaises(TypeError, p.__setitem__, 1, 2, 3)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -2624,6 +2624,22 @@ def _get_fake_tensor(vt):
return fake_tensor
def slice_length(s: slice, seq_len: int) -> int:
start, stop, step = s.indices(seq_len)
return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)
def raise_args_mismatch(tx, name):
from torch._dynamo.exc import raise_observed_exception
from torch._dynamo.variables import ConstantVariable
raise_observed_exception(
TypeError,
tx,
args=[ConstantVariable(f"wrong number of arguments for {name}() call")],
)
def iter_contains(items, search, tx, check_tensor_identity=False):
from .variables import (
BuiltinVariable,

View File

@ -549,10 +549,17 @@ class BuiltinVariable(VariableTracker):
def expand_list_like(tx: "InstructionTranslator", lst, const):
if isinstance(lst, ConstantVariable):
lst, const = const, lst
return lst.__class__(
items=lst.items * const.as_python_constant(),
mutation_type=ValueMutationNew(),
)
try:
return lst.__class__(
items=lst.items * const.as_python_constant(),
mutation_type=ValueMutationNew(),
)
except MemoryError as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
list_like_expansion_handlers: list[
tuple[

View File

@ -41,6 +41,7 @@ from ..utils import (
dict_keys,
dict_values,
istype,
raise_args_mismatch,
specialize_symnode,
)
from .base import ValueMutationNew, VariableTracker
@ -57,14 +58,6 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def raise_args_mismatch(tx, name):
raise_observed_exception(
TypeError,
tx,
args=[ConstantVariable(f"wrong number of arguments for {name}() call")],
)
def was_instancecheck_override(obj):
return type(obj).__dict__.get("__instancecheck__", False)

View File

@ -37,6 +37,7 @@ from ..utils import (
Lit,
namedtuple_fields,
odict_values,
raise_args_mismatch,
set_example_value,
)
from .base import ValueMutationNew, VariableTracker
@ -108,6 +109,9 @@ class BaseListVariable(VariableTracker):
index = arg.as_python_constant()
if isinstance(index, slice):
if index.step == 0:
msg = ConstantVariable.create("slice step cannot be zero")
raise_observed_exception(ValueError, tx, args=[msg])
# Set source to None because slicing a list gives a new local
return self.clone(
items=self.items[index],
@ -137,8 +141,10 @@ class BaseListVariable(VariableTracker):
from .tensor import TensorVariable
if len(args) != 1:
msg = f"{name} takes exactly one argument ({len(args)} given)"
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
msg = ConstantVariable.create(
f"{name} takes exactly one argument ({len(args)} given)"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert not kwargs and len(args) == 1
if isinstance(args[0], TensorVariable):
@ -159,14 +165,17 @@ class BaseListVariable(VariableTracker):
if value.python_type() not in (int, slice):
msg = f"indices must be integers or slices, not {value.python_type()}"
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])
return self.getitem_const(tx, value)
elif name == "__contains__":
assert len(args) == 1
assert not kwargs
if len(args) != 1 or kwargs:
raise_args_mismatch(tx, name)
return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
elif name == "index":
if not len(args):
raise_args_mismatch(tx, name)
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.index),
[self] + list(args),
@ -174,14 +183,16 @@ class BaseListVariable(VariableTracker):
)
elif name == "count":
if len(args) != 1:
msg = f"{name} takes exactly one argument ({len(args)} given)"
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
raise_args_mismatch(tx, name)
return VariableTracker.build(tx, operator.countOf).call_function(
tx,
[self, args[0]],
kwargs,
)
elif name in cmp_name_to_op_mapping:
if len(args) != 1:
raise_args_mismatch(tx, name)
left = self
right = args[0]
# TODO this type check logic mirrors the following
@ -397,24 +408,28 @@ class CommonListMethodsVariable(BaseListVariable):
if name == "append" and self.is_mutable():
assert not kwargs
if len(args) != 1:
raise_args_mismatch(tx, name)
(arg,) = args
tx.output.side_effects.mutation(self)
self.items.append(arg)
return ConstantVariable.create(None)
elif (
name == "extend"
and self.is_mutable()
and args
and args[0].has_force_unpack_var_sequence(tx)
):
assert not kwargs
elif name == "extend" and self.is_mutable():
if len(args) != 1 or kwargs:
raise_args_mismatch(tx, name)
if not args[0].has_force_unpack_var_sequence(tx):
msg = ConstantVariable.create(f"{type(args[0])} object is not iterable")
raise_observed_exception(TypeError, tx, args=[msg])
(arg,) = args
arg.force_apply_to_var_sequence(
tx, lambda item: self.call_method(tx, "append", [item], {})
)
return ConstantVariable.create(None)
elif name == "insert" and self.is_mutable():
assert not kwargs
if kwargs or len(args) != 2:
raise_args_mismatch(tx, name)
idx, value = args
if isinstance(idx, SymNodeVariable):
const_idx = idx.evaluate_expr()
@ -425,10 +440,23 @@ class CommonListMethodsVariable(BaseListVariable):
return ConstantVariable.create(None)
elif name == "pop" and self.is_mutable():
assert not kwargs
if kwargs or len(args) > 1:
raise_args_mismatch(tx, name)
if len(self.items) == 0:
msg = ConstantVariable.create("pop from empty list")
raise_observed_exception(IndexError, tx, args=[msg])
if len(args):
idx = args[0].as_python_constant()
if idx > len(self.items):
msg = ConstantVariable.create("pop index out of range")
raise_observed_exception(IndexError, tx, args=[msg])
tx.output.side_effects.mutation(self)
return self.items.pop(*[a.as_python_constant() for a in args])
elif name == "clear" and self.is_mutable():
assert not kwargs and not args
if args or kwargs:
raise_observed_exception(TypeError, tx)
tx.output.side_effects.mutation(self)
self.items.clear()
return ConstantVariable.create(None)
@ -471,13 +499,13 @@ class CommonListMethodsVariable(BaseListVariable):
return ConstantVariable.create(None)
elif name == "copy":
# List copy() doesn't have args and kwargs
assert not kwargs
assert not args
if args or kwargs:
raise_args_mismatch(tx, name)
items = list(self.items)
return self.modified(items, mutation_type=ValueMutationNew())
elif name == "reverse" and self.is_mutable():
assert not kwargs
assert not args
if args or kwargs:
raise_args_mismatch(tx, name)
self.items.reverse()
tx.output.side_effects.mutation(self)
return ConstantVariable.create(None)
@ -506,28 +534,47 @@ class ListVariable(CommonListMethodsVariable):
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if (
name == "__setitem__"
and self.is_mutable()
and args
and args[0].is_python_constant()
):
assert not kwargs
from .tensor import SymNodeVariable
if name == "__setitem__" and self.is_mutable():
if kwargs or len(args) != 2:
raise_args_mismatch(tx, name)
key, value = args
if not key.is_python_constant():
# probably will graph-break
super().call_method(tx, name, args, kwargs)
tx.output.side_effects.mutation(self)
if isinstance(key, SliceVariable):
if not value.has_force_unpack_var_sequence(tx):
unimplemented_v2(
gb_type="Unsupported conversion for slice assignment",
context=f"call_method {self} {name} {args}",
explanation=f"Missing dynamo support for converting {value} into a list for slice assignment.",
hints=[*graph_break_hints.SUPPORTABLE],
msg = ConstantVariable.create("can only assign an iterable")
raise_observed_exception(TypeError, tx, args=[msg])
key = key.as_python_constant()
if key.step == 0:
msg = ConstantVariable.create("slice step cannot be zero")
raise_observed_exception(ValueError, tx, args=[msg])
value = value.force_unpack_var_sequence(tx)
try:
self.items[key] = value
except Exception as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
tx
)
else:
self.items[key.as_python_constant()] = value
if isinstance(key, SymNodeVariable):
key = key.evaluate_expr()
else:
key = key.as_python_constant()
if key >= len(self.items) or key < -len(self.items):
msg = ConstantVariable.create("list index out of range")
raise_observed_exception(IndexError, tx, args=[msg])
self.items[key] = value
return ConstantVariable.create(None)
if name == "sort" and self.is_mutable():