mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[list] Add list.__mul__ and list.__imul__ (#156271)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156271 Approved by: https://github.com/zou3519 ghstack dependencies: #153969, #156148, #156242, #156270
This commit is contained in:
committed by
PyTorch MergeBot
parent
689fba032d
commit
d74ccf4ffe
@ -59,6 +59,20 @@ class TupleTests(torch._dynamo.test_case.TestCase):
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.index)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_imul(self):
|
||||
p = self.thetype([1, 2, 3])
|
||||
r = p.__mul__(2)
|
||||
self.assertIsInstance(r, self.thetype)
|
||||
self.assertEqual(r, self.thetype([1, 2, 3, 1, 2, 3]))
|
||||
self.assertEqual(p, self.thetype([1, 2, 3]))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__mul__)
|
||||
|
||||
# can only multiply list by an integer
|
||||
self.assertRaises(TypeError, p.__mul__, 2.2)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_add(self):
|
||||
p, q = map(self.thetype, ["abc", "bcd"])
|
||||
@ -276,6 +290,39 @@ class ListTests(TupleTests):
|
||||
self.assertIsNone(p.sort())
|
||||
self.assertEqual(p, self.thetype("abcd"))
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_imul(self):
|
||||
p = self.thetype([1, 2, 3])
|
||||
r = p.__imul__(2)
|
||||
self.assertIsInstance(r, self.thetype)
|
||||
self.assertEqual(r, self.thetype([1, 2, 3, 1, 2, 3]))
|
||||
self.assertEqual(p, self.thetype([1, 2, 3, 1, 2, 3]))
|
||||
|
||||
p = self.thetype("ab")
|
||||
p *= 2
|
||||
self.assertEqual(p, self.thetype("abab"))
|
||||
|
||||
# Wrong number of arguments
|
||||
self.assertRaises(TypeError, p.__imul__)
|
||||
|
||||
# can only multiply list by an integer
|
||||
self.assertRaises(TypeError, p.__imul__, 2.2)
|
||||
|
||||
def test_binop_imul_global_list(self):
|
||||
global lst
|
||||
lst = self.thetype(["a", "b"])
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x):
|
||||
global lst
|
||||
lst *= 2
|
||||
lst.__imul__(3)
|
||||
return x.sin()
|
||||
|
||||
x = torch.tensor(1.0)
|
||||
self.assertEqual(fn(x), x.sin())
|
||||
self.assertEqual(lst, ["a", "b"] * 6)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_iadd(self):
|
||||
p, q = map(self.thetype, ["abc", "bcd"])
|
||||
|
@ -206,6 +206,23 @@ class BaseListVariable(VariableTracker):
|
||||
else:
|
||||
self.items += args[0].items
|
||||
return self
|
||||
elif name in ("__mul__", "__imul__"):
|
||||
if kwargs or len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
if not (args[0].is_python_constant() and args[0].python_type() is int):
|
||||
msg = ConstantVariable.create(
|
||||
f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
val = args[0].as_python_constant()
|
||||
|
||||
if name == "__mul__":
|
||||
return type(self)(self.items * val, source=self.source)
|
||||
else:
|
||||
self.items *= val
|
||||
return self
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
Reference in New Issue
Block a user