mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[list] Raise exception in invalid list method call (#156148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156148 Approved by: https://github.com/zou3519 ghstack dependencies: #153969
This commit is contained in:
committed by
PyTorch MergeBot
parent
034e996d37
commit
e49acfc5c5
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py
|
||||
index dbc5ef4f9f2..2b9f3b9311f 100644
|
||||
index dbc5ef4f9f2..239b75f74cc 100644
|
||||
--- a/test/dynamo/cpython/3_13/list_tests.py
|
||||
+++ b/test/dynamo/cpython/3_13/list_tests.py
|
||||
@@ -1,3 +1,53 @@
|
||||
@ -65,3 +65,14 @@ index dbc5ef4f9f2..2b9f3b9311f 100644
|
||||
from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit
|
||||
|
||||
|
||||
@@ -119,10 +169,6 @@ class CommonTest(seq_tests.CommonTest):
|
||||
a[-1] = 9
|
||||
self.assertEqual(a, self.type2test([5,6,7,8,9]))
|
||||
|
||||
- msg = "list indices must be integers or slices"
|
||||
- with self.assertRaisesRegex(TypeError, msg):
|
||||
- a['a'] = "python"
|
||||
-
|
||||
def test_delitem(self):
|
||||
a = self.type2test([0, 1])
|
||||
del a[1]
|
||||
|
@ -169,10 +169,6 @@ class CommonTest(seq_tests.CommonTest):
|
||||
a[-1] = 9
|
||||
self.assertEqual(a, self.type2test([5,6,7,8,9]))
|
||||
|
||||
msg = "list indices must be integers or slices"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
a['a'] = "python"
|
||||
|
||||
def test_delitem(self):
|
||||
a = self.type2test([0, 1])
|
||||
del a[1]
|
||||
|
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py
|
||||
index 23ef902aa0b..30e69ff75bd 100644
|
||||
index 23ef902aa0b..6e4c6d99d16 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_list.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_list.py
|
||||
@@ -1,6 +1,57 @@
|
||||
@ -61,7 +61,14 @@ index 23ef902aa0b..30e69ff75bd 100644
|
||||
from test.support import cpython_only
|
||||
from test.support.script_helper import assert_python_ok
|
||||
import pickle
|
||||
@@ -324,6 +375,7 @@ class ListTest(list_tests.CommonTest):
|
||||
@@ -35,8 +86,6 @@ class ListTest(list_tests.CommonTest):
|
||||
# Note: This test is expected to SEGV under Cygwin 1.3.12 or
|
||||
# earlier due to a newlib bug. See the following mailing list
|
||||
# thread for the details:
|
||||
self.assertRaises(MemoryError, list, range(sys.maxsize // 2))
|
||||
|
||||
# This code used to segfault in Py2.4a3
|
||||
@@ -324,6 +373,7 @@ class ListTest(list_tests.CommonTest):
|
||||
a.append(4)
|
||||
self.assertEqual(list(it), [])
|
||||
|
||||
@ -69,7 +76,7 @@ index 23ef902aa0b..30e69ff75bd 100644
|
||||
def test_deopt_from_append_list(self):
|
||||
# gh-132011: it used to crash, because
|
||||
# of `CALL_LIST_APPEND` specialization failure.
|
||||
@@ -345,4 +397,4 @@ class ListTest(list_tests.CommonTest):
|
||||
@@ -345,4 +395,4 @@ class ListTest(list_tests.CommonTest):
|
||||
self.assertEqual(rc, 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
305
test/dynamo/test_list.py
Normal file
305
test/dynamo/test_list.py
Normal file
@ -0,0 +1,305 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
# TODO: move set tests from test_functions.py/test_misc.py to this file
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
class TupleTests(torch._dynamo.test_case.TestCase):
|
||||
# Tuple methods
|
||||
# + count
|
||||
# + index
|
||||
# BinOps:
|
||||
# +, <, >, <=, >=, ==, !=
|
||||
# Dunder methods:
|
||||
# + __getitem__
|
||||
# + __contains__
|
||||
# + __delitem__
|
||||
|
||||
thetype = tuple
|
||||
|
||||
def setUp(self):
|
||||
self.old = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self.old
|
||||
return super().tearDown()
|
||||
|
||||
def assertEqual(self, a, b):
|
||||
return self.assertTrue(a == b, f"{a} != {b}")
|
||||
|
||||
def assertNotEqual(self, x, y, msg=None, *, atol=None, rtol=None, **kwargs):
|
||||
return self.assertTrue(x != y, f"{x} == {y}")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_count(self):
|
||||
p = self.thetype("abcab")
|
||||
self.assertEqual(p.count("a"), 2)
|
||||
self.assertEqual(p.count("ab"), 0)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.count)
|
||||
self.assertRaises(TypeError, p.count, 2, 3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_index(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertEqual(p.index("a"), 0)
|
||||
self.assertRaises(ValueError, p.index, "e")
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.index)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_binop_add(self):
|
||||
p, q = map(self.thetype, ["abc", "bcd"])
|
||||
self.assertIsInstance(p + q, self.thetype)
|
||||
self.assertEqual(p + q, self.thetype("abcbcd"))
|
||||
self.assertEqual(p.__add__(q), self.thetype("abcbcd"))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__add__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cmp_eq(self):
|
||||
p, q, r = map(self.thetype, ["ab", "abc", "ab"])
|
||||
self.assertTrue(p == p)
|
||||
self.assertTrue(p == r)
|
||||
self.assertEqual(p, p)
|
||||
self.assertEqual(p, r)
|
||||
self.assertNotEqual(p, q)
|
||||
self.assertTrue(p.__eq__(r))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__eq__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cmp_ne(self):
|
||||
p, q = map(self.thetype, ["ab", "abc"])
|
||||
self.assertTrue(p != q)
|
||||
self.assertNotEqual(p, q)
|
||||
self.assertTrue(p.__ne__(q))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__ne__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cmp_less_than(self):
|
||||
p, q = map(self.thetype, ["ab", "abc"])
|
||||
self.assertTrue(p < q)
|
||||
self.assertTrue(p.__lt__(q))
|
||||
self.assertFalse(q < p)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__lt__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cmp_greater_than(self):
|
||||
p, q = map(self.thetype, ["ab", "abc"])
|
||||
self.assertTrue(q > p)
|
||||
self.assertTrue(q.__gt__(p))
|
||||
self.assertFalse(p > q)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__gt__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cmp_less_than_or_equal(self):
|
||||
p, q = map(self.thetype, ["ab", "abc"])
|
||||
self.assertTrue(p <= q)
|
||||
self.assertTrue(p.__le__(q))
|
||||
self.assertFalse(q <= p)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__le__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cmp_greater_than_or_equal(self):
|
||||
p, q = map(self.thetype, ["ab", "abc"])
|
||||
self.assertTrue(q >= p)
|
||||
self.assertTrue(q.__ge__(p))
|
||||
self.assertFalse(p >= q)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__ge__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test___getitem__(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertEqual(p.__getitem__(2), "c")
|
||||
self.assertRaises(IndexError, p.__getitem__, 10)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__getitem__)
|
||||
self.assertRaises(TypeError, p.__getitem__, 1, 2)
|
||||
|
||||
@make_dynamo_test
|
||||
def test___contains__(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertTrue(p.__contains__("a"))
|
||||
self.assertIsInstance(p.__contains__("c"), bool)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__contains__)
|
||||
self.assertRaises(TypeError, p.__contains__, 1, 2)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test___delitem__(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertIsNone(p.__delitem__(1))
|
||||
self.assertEqual(p, self.thetype("ac"))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__delitem__)
|
||||
self.assertRaises(TypeError, p.__delitem__, 1, 2)
|
||||
|
||||
|
||||
class ListTests(TupleTests):
|
||||
# List methods
|
||||
# + append
|
||||
# + copy
|
||||
# + clear
|
||||
# + extend
|
||||
# + insert
|
||||
# + pop
|
||||
# + remove
|
||||
# + reverse
|
||||
# + sort
|
||||
# BinOps:
|
||||
# +, <, >, <=, >=, ==, !=
|
||||
# Dunder methods:
|
||||
# + __setitem__
|
||||
# + __getitem__
|
||||
# + __contains__
|
||||
# + __delitem__
|
||||
|
||||
thetype = list
|
||||
|
||||
@make_dynamo_test
|
||||
def test_append(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertIsNone(p.append("d"))
|
||||
self.assertEqual(p, ["a", "b", "c", "d"])
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.append)
|
||||
self.assertRaises(TypeError, p.append, 2, 3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_copy(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertEqual(p.copy(), p)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.copy, 1)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_clear(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertIsNone(p.clear())
|
||||
self.assertEqual(p, [])
|
||||
self.assertEqual(len(p), 0)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.clear, 1)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_extend(self):
|
||||
p, q = map(self.thetype, ["ab", "cd"])
|
||||
self.assertIsNone(p.extend(q))
|
||||
self.assertEqual(p, self.thetype("abcd"))
|
||||
|
||||
# extend needs an iterable
|
||||
self.assertRaises(TypeError, p.extend, 1)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.extend)
|
||||
self.assertRaises(TypeError, p.extend, 2, 3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_insert(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertIsNone(p.insert(1, "ef"))
|
||||
self.assertEqual(p, ["a", "ef", "b", "c"])
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.insert)
|
||||
self.assertRaises(TypeError, p.insert, 1)
|
||||
self.assertRaises(TypeError, p.insert, 1, 2, 3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_pop(self):
|
||||
p = self.thetype("abcd")
|
||||
self.assertEqual(p.pop(), "d")
|
||||
self.assertEqual(p.pop(1), "b")
|
||||
self.assertRaises(IndexError, p.pop, 10)
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.pop, 2, 3)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_remove(self):
|
||||
p = self.thetype("abad")
|
||||
self.assertIsNone(p.remove("a"))
|
||||
self.assertEqual(p, ["b", "a", "d"])
|
||||
self.assertRaises(ValueError, p.remove, "x")
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.remove)
|
||||
self.assertRaises(TypeError, p.remove, 2, 3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_reverse(self):
|
||||
p = self.thetype("abcd")
|
||||
self.assertIsNone(p.reverse())
|
||||
self.assertEqual(p, self.thetype("dcba"))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.reverse, 1)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_sort(self):
|
||||
p = self.thetype("dbca")
|
||||
self.assertIsNone(p.sort())
|
||||
self.assertEqual(p, self.thetype("abcd"))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_binop_iadd(self):
|
||||
p, q = map(self.thetype, ["abc", "bcd"])
|
||||
r = p.__iadd__(q)
|
||||
self.assertIsInstance(r, self.thetype)
|
||||
self.assertEqual(r, self.thetype("abcbcd"))
|
||||
self.assertEqual(p, self.thetype("abcbcd"))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__iadd__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test___setitem__(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertIsNone(p.__setitem__(2, "a"))
|
||||
self.assertEqual(p, self.thetype("aba"))
|
||||
|
||||
p[0:] = []
|
||||
self.assertEqual(p, [])
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__setitem__)
|
||||
self.assertRaises(TypeError, p.__setitem__, 1)
|
||||
self.assertRaises(TypeError, p.__setitem__, 1, 2, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -2624,6 +2624,22 @@ def _get_fake_tensor(vt):
|
||||
return fake_tensor
|
||||
|
||||
|
||||
def slice_length(s: slice, seq_len: int) -> int:
|
||||
start, stop, step = s.indices(seq_len)
|
||||
return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)
|
||||
|
||||
|
||||
def raise_args_mismatch(tx, name):
|
||||
from torch._dynamo.exc import raise_observed_exception
|
||||
from torch._dynamo.variables import ConstantVariable
|
||||
|
||||
raise_observed_exception(
|
||||
TypeError,
|
||||
tx,
|
||||
args=[ConstantVariable(f"wrong number of arguments for {name}() call")],
|
||||
)
|
||||
|
||||
|
||||
def iter_contains(items, search, tx, check_tensor_identity=False):
|
||||
from .variables import (
|
||||
BuiltinVariable,
|
||||
|
@ -549,10 +549,17 @@ class BuiltinVariable(VariableTracker):
|
||||
def expand_list_like(tx: "InstructionTranslator", lst, const):
|
||||
if isinstance(lst, ConstantVariable):
|
||||
lst, const = const, lst
|
||||
return lst.__class__(
|
||||
items=lst.items * const.as_python_constant(),
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
try:
|
||||
return lst.__class__(
|
||||
items=lst.items * const.as_python_constant(),
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
except MemoryError as exc:
|
||||
raise_observed_exception(
|
||||
type(exc),
|
||||
tx,
|
||||
args=list(map(ConstantVariable.create, exc.args)),
|
||||
)
|
||||
|
||||
list_like_expansion_handlers: list[
|
||||
tuple[
|
||||
|
@ -41,6 +41,7 @@ from ..utils import (
|
||||
dict_keys,
|
||||
dict_values,
|
||||
istype,
|
||||
raise_args_mismatch,
|
||||
specialize_symnode,
|
||||
)
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
@ -57,14 +58,6 @@ if TYPE_CHECKING:
|
||||
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
|
||||
|
||||
|
||||
def raise_args_mismatch(tx, name):
|
||||
raise_observed_exception(
|
||||
TypeError,
|
||||
tx,
|
||||
args=[ConstantVariable(f"wrong number of arguments for {name}() call")],
|
||||
)
|
||||
|
||||
|
||||
def was_instancecheck_override(obj):
|
||||
return type(obj).__dict__.get("__instancecheck__", False)
|
||||
|
||||
|
@ -37,6 +37,7 @@ from ..utils import (
|
||||
Lit,
|
||||
namedtuple_fields,
|
||||
odict_values,
|
||||
raise_args_mismatch,
|
||||
set_example_value,
|
||||
)
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
@ -108,6 +109,9 @@ class BaseListVariable(VariableTracker):
|
||||
index = arg.as_python_constant()
|
||||
|
||||
if isinstance(index, slice):
|
||||
if index.step == 0:
|
||||
msg = ConstantVariable.create("slice step cannot be zero")
|
||||
raise_observed_exception(ValueError, tx, args=[msg])
|
||||
# Set source to None because slicing a list gives a new local
|
||||
return self.clone(
|
||||
items=self.items[index],
|
||||
@ -137,8 +141,10 @@ class BaseListVariable(VariableTracker):
|
||||
from .tensor import TensorVariable
|
||||
|
||||
if len(args) != 1:
|
||||
msg = f"{name} takes exactly one argument ({len(args)} given)"
|
||||
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
|
||||
msg = ConstantVariable.create(
|
||||
f"{name} takes exactly one argument ({len(args)} given)"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
assert not kwargs and len(args) == 1
|
||||
if isinstance(args[0], TensorVariable):
|
||||
@ -159,14 +165,17 @@ class BaseListVariable(VariableTracker):
|
||||
|
||||
if value.python_type() not in (int, slice):
|
||||
msg = f"indices must be integers or slices, not {value.python_type()}"
|
||||
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
|
||||
raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])
|
||||
|
||||
return self.getitem_const(tx, value)
|
||||
elif name == "__contains__":
|
||||
assert len(args) == 1
|
||||
assert not kwargs
|
||||
if len(args) != 1 or kwargs:
|
||||
raise_args_mismatch(tx, name)
|
||||
return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
|
||||
elif name == "index":
|
||||
if not len(args):
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.index),
|
||||
[self] + list(args),
|
||||
@ -174,14 +183,16 @@ class BaseListVariable(VariableTracker):
|
||||
)
|
||||
elif name == "count":
|
||||
if len(args) != 1:
|
||||
msg = f"{name} takes exactly one argument ({len(args)} given)"
|
||||
raise_observed_exception(TypeError, tx, [ConstantVariable(msg)])
|
||||
raise_args_mismatch(tx, name)
|
||||
return VariableTracker.build(tx, operator.countOf).call_function(
|
||||
tx,
|
||||
[self, args[0]],
|
||||
kwargs,
|
||||
)
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
left = self
|
||||
right = args[0]
|
||||
# TODO this type check logic mirrors the following
|
||||
@ -397,24 +408,28 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
|
||||
if name == "append" and self.is_mutable():
|
||||
assert not kwargs
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
(arg,) = args
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.append(arg)
|
||||
return ConstantVariable.create(None)
|
||||
elif (
|
||||
name == "extend"
|
||||
and self.is_mutable()
|
||||
and args
|
||||
and args[0].has_force_unpack_var_sequence(tx)
|
||||
):
|
||||
assert not kwargs
|
||||
elif name == "extend" and self.is_mutable():
|
||||
if len(args) != 1 or kwargs:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
if not args[0].has_force_unpack_var_sequence(tx):
|
||||
msg = ConstantVariable.create(f"{type(args[0])} object is not iterable")
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
(arg,) = args
|
||||
arg.force_apply_to_var_sequence(
|
||||
tx, lambda item: self.call_method(tx, "append", [item], {})
|
||||
)
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "insert" and self.is_mutable():
|
||||
assert not kwargs
|
||||
if kwargs or len(args) != 2:
|
||||
raise_args_mismatch(tx, name)
|
||||
idx, value = args
|
||||
if isinstance(idx, SymNodeVariable):
|
||||
const_idx = idx.evaluate_expr()
|
||||
@ -425,10 +440,23 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "pop" and self.is_mutable():
|
||||
assert not kwargs
|
||||
if kwargs or len(args) > 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
if len(self.items) == 0:
|
||||
msg = ConstantVariable.create("pop from empty list")
|
||||
raise_observed_exception(IndexError, tx, args=[msg])
|
||||
|
||||
if len(args):
|
||||
idx = args[0].as_python_constant()
|
||||
if idx > len(self.items):
|
||||
msg = ConstantVariable.create("pop index out of range")
|
||||
raise_observed_exception(IndexError, tx, args=[msg])
|
||||
tx.output.side_effects.mutation(self)
|
||||
return self.items.pop(*[a.as_python_constant() for a in args])
|
||||
elif name == "clear" and self.is_mutable():
|
||||
assert not kwargs and not args
|
||||
if args or kwargs:
|
||||
raise_observed_exception(TypeError, tx)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.clear()
|
||||
return ConstantVariable.create(None)
|
||||
@ -471,13 +499,13 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "copy":
|
||||
# List copy() doesn't have args and kwargs
|
||||
assert not kwargs
|
||||
assert not args
|
||||
if args or kwargs:
|
||||
raise_args_mismatch(tx, name)
|
||||
items = list(self.items)
|
||||
return self.modified(items, mutation_type=ValueMutationNew())
|
||||
elif name == "reverse" and self.is_mutable():
|
||||
assert not kwargs
|
||||
assert not args
|
||||
if args or kwargs:
|
||||
raise_args_mismatch(tx, name)
|
||||
self.items.reverse()
|
||||
tx.output.side_effects.mutation(self)
|
||||
return ConstantVariable.create(None)
|
||||
@ -506,28 +534,47 @@ class ListVariable(CommonListMethodsVariable):
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if (
|
||||
name == "__setitem__"
|
||||
and self.is_mutable()
|
||||
and args
|
||||
and args[0].is_python_constant()
|
||||
):
|
||||
assert not kwargs
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
if name == "__setitem__" and self.is_mutable():
|
||||
if kwargs or len(args) != 2:
|
||||
raise_args_mismatch(tx, name)
|
||||
key, value = args
|
||||
|
||||
if not key.is_python_constant():
|
||||
# probably will graph-break
|
||||
super().call_method(tx, name, args, kwargs)
|
||||
|
||||
tx.output.side_effects.mutation(self)
|
||||
if isinstance(key, SliceVariable):
|
||||
if not value.has_force_unpack_var_sequence(tx):
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported conversion for slice assignment",
|
||||
context=f"call_method {self} {name} {args}",
|
||||
explanation=f"Missing dynamo support for converting {value} into a list for slice assignment.",
|
||||
hints=[*graph_break_hints.SUPPORTABLE],
|
||||
msg = ConstantVariable.create("can only assign an iterable")
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
key = key.as_python_constant()
|
||||
if key.step == 0:
|
||||
msg = ConstantVariable.create("slice step cannot be zero")
|
||||
raise_observed_exception(ValueError, tx, args=[msg])
|
||||
|
||||
value = value.force_unpack_var_sequence(tx)
|
||||
try:
|
||||
self.items[key] = value
|
||||
except Exception as exc:
|
||||
raise_observed_exception(
|
||||
type(exc),
|
||||
tx,
|
||||
args=list(map(ConstantVariable.create, exc.args)),
|
||||
)
|
||||
self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
|
||||
tx
|
||||
)
|
||||
else:
|
||||
self.items[key.as_python_constant()] = value
|
||||
if isinstance(key, SymNodeVariable):
|
||||
key = key.evaluate_expr()
|
||||
else:
|
||||
key = key.as_python_constant()
|
||||
|
||||
if key >= len(self.items) or key < -len(self.items):
|
||||
msg = ConstantVariable.create("list index out of range")
|
||||
raise_observed_exception(IndexError, tx, args=[msg])
|
||||
self.items[key] = value
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
if name == "sort" and self.is_mutable():
|
||||
|
Reference in New Issue
Block a user