From eda0a9cc90b9a63127a49d617329f98b6404e90d Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Sat, 5 Jul 2025 15:54:47 -0300 Subject: [PATCH] [list] Add list.__delitem__ (#156339) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156339 Approved by: https://github.com/zou3519 ghstack dependencies: #153969, #156148, #156242, #156270, #156271 --- test/dynamo/test_list.py | 43 +++++++++++++------ ...CPython313-test_list-ListTest.test_delitem | 0 ...Python313-test_list-ListTest.test_delslice | 0 ...13-test_list-ListTest.test_extendedslicing | 0 ...st_list-ListTest.test_list_resize_overflow | 0 torch/_dynamo/variables/lists.py | 27 ++++++++++++ 6 files changed, 58 insertions(+), 12 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delitem delete mode 100644 test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delslice delete mode 100644 test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extendedslicing delete mode 100644 test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_resize_overflow diff --git a/test/dynamo/test_list.py b/test/dynamo/test_list.py index e6cde5dfaf42..60c799d0b6a4 100644 --- a/test/dynamo/test_list.py +++ b/test/dynamo/test_list.py @@ -2,7 +2,6 @@ # TODO: move set tests from test_functions.py/test_misc.py to this file -import unittest import torch import torch._dynamo.test_case @@ -169,17 +168,6 @@ class TupleTests(torch._dynamo.test_case.TestCase): self.assertRaises(TypeError, p.__contains__) self.assertRaises(TypeError, p.__contains__, 1, 2) - @unittest.expectedFailure - @make_dynamo_test - def test___delitem__(self): - p = self.thetype("abc") - self.assertIsNone(p.__delitem__(1)) - self.assertEqual(p, self.thetype("ac")) - - # Wrong number of arguments - self.assertRaises(TypeError, p.__delitem__) - self.assertRaises(TypeError, p.__delitem__, 1, 2) - class ListTests(TupleTests): # List methods @@ -356,6 +344,20 @@ class ListTests(TupleTests): self.assertEqual(fn(x), x.sin()) self.assertEqual(lst, ["a", "b"]) + def test_binop_delitem_global_list(self): + global lst + lst = self.thetype(["a", "b", "c"]) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + global lst + del lst[1] + return x.sin() + + x = torch.tensor(1.0) + self.assertEqual(fn(x), x.sin()) + self.assertEqual(lst, ["a", "c"]) + @make_dynamo_test def test___setitem__(self): p = self.thetype("abc") @@ -370,6 +372,23 @@ class ListTests(TupleTests): self.assertRaises(TypeError, p.__setitem__, 1) self.assertRaises(TypeError, p.__setitem__, 1, 2, 3) + @make_dynamo_test + def test___delitem__(self): + p = self.thetype("abcdef") + self.assertIsNone(p.__delitem__(1)) + self.assertEqual(p, self.thetype("acdef")) + + self.assertIsNone(p.__delitem__(slice(1, 3))) + self.assertEqual(p, self.thetype("aef")) + + # Slice step == 0 + self.assertRaises(ValueError, p.__delitem__, slice(1, 1, 0)) + + # Wrong number of arguments + self.assertRaises(TypeError, p.__delitem__) + self.assertRaises(TypeError, p.__delitem__, 1.1) + self.assertRaises(TypeError, p.__delitem__, 1, 2) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delitem b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delitem deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delslice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delslice deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extendedslicing b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extendedslicing deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_resize_overflow b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_resize_overflow deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index f097ebf69763..93547c79e956 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -531,6 +531,33 @@ class CommonListMethodsVariable(BaseListVariable): else: self.items[key.as_python_constant()] = value return ConstantVariable.create(None) + elif name == "__delitem__" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch(tx, name) + + tx.output.side_effects.mutation(self) + if args[0].is_python_constant() and isinstance( + args[0].as_python_constant(), (int, slice) + ): + if isinstance(args[0], SymNodeVariable): + idx = args[0].evaluate_expr() + else: + idx = args[0].as_python_constant() + + try: + self.items.__delitem__(idx) + except (IndexError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + else: + msg = ConstantVariable.create( + f"list indices must be integers or slices, not {args[0].python_type_name()}" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + return ConstantVariable.create(None) elif name == "copy": # List copy() doesn't have args and kwargs if args or kwargs: