[list] Add list.__mul__ and list.__imul__ (#156271)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156271
Approved by: https://github.com/zou3519
ghstack dependencies: #153969, #156148, #156242, #156270
This commit is contained in:
Guilherme Leobas
2025-07-05 15:54:46 -03:00
committed by PyTorch MergeBot
parent 689fba032d
commit d74ccf4ffe
2 changed files with 64 additions and 0 deletions

View File

@ -59,6 +59,20 @@ class TupleTests(torch._dynamo.test_case.TestCase):
# Wrong number of arguments
self.assertRaises(TypeError, p.index)
@make_dynamo_test
def test_binop_imul(self):
p = self.thetype([1, 2, 3])
r = p.__mul__(2)
self.assertIsInstance(r, self.thetype)
self.assertEqual(r, self.thetype([1, 2, 3, 1, 2, 3]))
self.assertEqual(p, self.thetype([1, 2, 3]))
# Wrong number of arguments
self.assertRaises(TypeError, p.__mul__)
# can only multiply list by an integer
self.assertRaises(TypeError, p.__mul__, 2.2)
@make_dynamo_test
def test_binop_add(self):
p, q = map(self.thetype, ["abc", "bcd"])
@ -276,6 +290,39 @@ class ListTests(TupleTests):
self.assertIsNone(p.sort())
self.assertEqual(p, self.thetype("abcd"))
@make_dynamo_test
def test_binop_imul(self):
p = self.thetype([1, 2, 3])
r = p.__imul__(2)
self.assertIsInstance(r, self.thetype)
self.assertEqual(r, self.thetype([1, 2, 3, 1, 2, 3]))
self.assertEqual(p, self.thetype([1, 2, 3, 1, 2, 3]))
p = self.thetype("ab")
p *= 2
self.assertEqual(p, self.thetype("abab"))
# Wrong number of arguments
self.assertRaises(TypeError, p.__imul__)
# can only multiply list by an integer
self.assertRaises(TypeError, p.__imul__, 2.2)
def test_binop_imul_global_list(self):
global lst
lst = self.thetype(["a", "b"])
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
global lst
lst *= 2
lst.__imul__(3)
return x.sin()
x = torch.tensor(1.0)
self.assertEqual(fn(x), x.sin())
self.assertEqual(lst, ["a", "b"] * 6)
@make_dynamo_test
def test_binop_iadd(self):
p, q = map(self.thetype, ["abc", "bcd"])

View File

@ -206,6 +206,23 @@ class BaseListVariable(VariableTracker):
else:
self.items += args[0].items
return self
elif name in ("__mul__", "__imul__"):
if kwargs or len(args) != 1:
raise_args_mismatch(tx, name)
if not (args[0].is_python_constant() and args[0].python_type() is int):
msg = ConstantVariable.create(
f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
val = args[0].as_python_constant()
if name == "__mul__":
return type(self)(self.items * val, source=self.source)
else:
self.items *= val
return self
elif name in cmp_name_to_op_mapping:
if len(args) != 1:
raise_args_mismatch(tx, name)