mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] allow symints in list.__setitem__ (#156197)
Fixes https://github.com/pytorch/pytorch/issues/155174 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156197 Approved by: https://github.com/StrongerXi
This commit is contained in:
committed by
PyTorch MergeBot
parent
162ca185ff
commit
44a5f93462
@ -2978,6 +2978,26 @@ class GraphModule(torch.nn.Module):
|
||||
opt_fn = torch.compile(fullgraph=True, backend="eager")(fn)
|
||||
self.assertEqual(opt_fn(x, a, b), fn(x, a, b))
|
||||
|
||||
def test_list_setitem(self):
|
||||
def fn(a: int):
|
||||
some_array = [1, 2, 3]
|
||||
some_array[a] = 5
|
||||
return torch.ones(some_array)
|
||||
|
||||
opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
|
||||
self.assertEqual(opt_fn(0), fn(0))
|
||||
self.assertEqual(opt_fn(1), fn(1))
|
||||
|
||||
def test_list_setitem_slice(self):
|
||||
def fn(a: int):
|
||||
some_array = [1, 2, 3]
|
||||
some_array[a : a + 1] = [5]
|
||||
return torch.ones(some_array)
|
||||
|
||||
opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
|
||||
self.assertEqual(opt_fn(0), fn(0))
|
||||
self.assertEqual(opt_fn(1), fn(1))
|
||||
|
||||
def test_pow_int(self):
|
||||
def fn(a, b):
|
||||
return torch.pow(a, b)
|
||||
|
@ -418,13 +418,36 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
name == "__setitem__"
|
||||
and self.is_mutable()
|
||||
and args
|
||||
and args[0].is_python_constant()
|
||||
and (
|
||||
args[0].is_python_constant()
|
||||
or isinstance(args[0], SymNodeVariable)
|
||||
or (
|
||||
isinstance(args[0], SliceVariable)
|
||||
and all(
|
||||
s.is_python_constant() or isinstance(s, SymNodeVariable)
|
||||
for s in args[0].items
|
||||
)
|
||||
)
|
||||
)
|
||||
):
|
||||
assert not kwargs
|
||||
key, value = args
|
||||
tx.output.side_effects.mutation(self)
|
||||
if isinstance(key, SliceVariable):
|
||||
self.items[key.as_python_constant()] = list(value.items)
|
||||
if isinstance(key, SymNodeVariable):
|
||||
self.items[key.evaluate_expr()] = value
|
||||
elif isinstance(key, SliceVariable):
|
||||
if key.is_python_constant():
|
||||
self.items[key.as_python_constant()] = list(value.items)
|
||||
else:
|
||||
items = slice(
|
||||
*[
|
||||
s.evaluate_expr()
|
||||
if isinstance(s, SymNodeVariable)
|
||||
else s.as_python_constant()
|
||||
for s in key.items
|
||||
]
|
||||
)
|
||||
self.items[items] = list(value.items)
|
||||
else:
|
||||
self.items[key.as_python_constant()] = value
|
||||
return ConstantVariable.create(None)
|
||||
|
Reference in New Issue
Block a user