[dynamo] allow symints in list.__setitem__ (#156197)

Fixes https://github.com/pytorch/pytorch/issues/155174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156197
Approved by: https://github.com/StrongerXi
This commit is contained in:
Isuru Fernando
2025-06-20 21:27:12 +00:00
committed by PyTorch MergeBot
parent 162ca185ff
commit 44a5f93462
2 changed files with 46 additions and 3 deletions

View File

@ -2978,6 +2978,26 @@ class GraphModule(torch.nn.Module):
opt_fn = torch.compile(fullgraph=True, backend="eager")(fn)
self.assertEqual(opt_fn(x, a, b), fn(x, a, b))
def test_list_setitem(self):
def fn(a: int):
some_array = [1, 2, 3]
some_array[a] = 5
return torch.ones(some_array)
opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
self.assertEqual(opt_fn(0), fn(0))
self.assertEqual(opt_fn(1), fn(1))
def test_list_setitem_slice(self):
def fn(a: int):
some_array = [1, 2, 3]
some_array[a : a + 1] = [5]
return torch.ones(some_array)
opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
self.assertEqual(opt_fn(0), fn(0))
self.assertEqual(opt_fn(1), fn(1))
def test_pow_int(self):
def fn(a, b):
return torch.pow(a, b)

View File

@ -418,13 +418,36 @@ class CommonListMethodsVariable(BaseListVariable):
name == "__setitem__"
and self.is_mutable()
and args
and args[0].is_python_constant()
and (
args[0].is_python_constant()
or isinstance(args[0], SymNodeVariable)
or (
isinstance(args[0], SliceVariable)
and all(
s.is_python_constant() or isinstance(s, SymNodeVariable)
for s in args[0].items
)
)
)
):
assert not kwargs
key, value = args
tx.output.side_effects.mutation(self)
if isinstance(key, SliceVariable):
self.items[key.as_python_constant()] = list(value.items)
if isinstance(key, SymNodeVariable):
self.items[key.evaluate_expr()] = value
elif isinstance(key, SliceVariable):
if key.is_python_constant():
self.items[key.as_python_constant()] = list(value.items)
else:
items = slice(
*[
s.evaluate_expr()
if isinstance(s, SymNodeVariable)
else s.as_python_constant()
for s in key.items
]
)
self.items[items] = list(value.items)
else:
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)