mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[list] Add list.__delitem__ (#156339)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156339 Approved by: https://github.com/zou3519 ghstack dependencies: #153969, #156148, #156242, #156270, #156271
This commit is contained in:
committed by
PyTorch MergeBot
parent
d74ccf4ffe
commit
eda0a9cc90
@ -2,7 +2,6 @@
|
||||
|
||||
# TODO: move set tests from test_functions.py/test_misc.py to this file
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
@ -169,17 +168,6 @@ class TupleTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertRaises(TypeError, p.__contains__)
|
||||
self.assertRaises(TypeError, p.__contains__, 1, 2)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test___delitem__(self):
|
||||
p = self.thetype("abc")
|
||||
self.assertIsNone(p.__delitem__(1))
|
||||
self.assertEqual(p, self.thetype("ac"))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__delitem__)
|
||||
self.assertRaises(TypeError, p.__delitem__, 1, 2)
|
||||
|
||||
|
||||
class ListTests(TupleTests):
|
||||
# List methods
|
||||
@ -356,6 +344,20 @@ class ListTests(TupleTests):
|
||||
self.assertEqual(fn(x), x.sin())
|
||||
self.assertEqual(lst, ["a", "b"])
|
||||
|
||||
def test_binop_delitem_global_list(self):
|
||||
global lst
|
||||
lst = self.thetype(["a", "b", "c"])
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x):
|
||||
global lst
|
||||
del lst[1]
|
||||
return x.sin()
|
||||
|
||||
x = torch.tensor(1.0)
|
||||
self.assertEqual(fn(x), x.sin())
|
||||
self.assertEqual(lst, ["a", "c"])
|
||||
|
||||
@make_dynamo_test
|
||||
def test___setitem__(self):
|
||||
p = self.thetype("abc")
|
||||
@ -370,6 +372,23 @@ class ListTests(TupleTests):
|
||||
self.assertRaises(TypeError, p.__setitem__, 1)
|
||||
self.assertRaises(TypeError, p.__setitem__, 1, 2, 3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test___delitem__(self):
|
||||
p = self.thetype("abcdef")
|
||||
self.assertIsNone(p.__delitem__(1))
|
||||
self.assertEqual(p, self.thetype("acdef"))
|
||||
|
||||
self.assertIsNone(p.__delitem__(slice(1, 3)))
|
||||
self.assertEqual(p, self.thetype("aef"))
|
||||
|
||||
# Slice step == 0
|
||||
self.assertRaises(ValueError, p.__delitem__, slice(1, 1, 0))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__delitem__)
|
||||
self.assertRaises(TypeError, p.__delitem__, 1.1)
|
||||
self.assertRaises(TypeError, p.__delitem__, 1, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -531,6 +531,33 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
else:
|
||||
self.items[key.as_python_constant()] = value
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__delitem__" and self.is_mutable():
|
||||
if kwargs or len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
tx.output.side_effects.mutation(self)
|
||||
if args[0].is_python_constant() and isinstance(
|
||||
args[0].as_python_constant(), (int, slice)
|
||||
):
|
||||
if isinstance(args[0], SymNodeVariable):
|
||||
idx = args[0].evaluate_expr()
|
||||
else:
|
||||
idx = args[0].as_python_constant()
|
||||
|
||||
try:
|
||||
self.items.__delitem__(idx)
|
||||
except (IndexError, ValueError) as exc:
|
||||
raise_observed_exception(
|
||||
type(exc),
|
||||
tx,
|
||||
args=list(map(ConstantVariable.create, exc.args)),
|
||||
)
|
||||
else:
|
||||
msg = ConstantVariable.create(
|
||||
f"list indices must be integers or slices, not {args[0].python_type_name()}"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "copy":
|
||||
# List copy() doesn't have args and kwargs
|
||||
if args or kwargs:
|
||||
|
Reference in New Issue
Block a user