[dynamo] Support is comparison for symnodes (#140754)

Fixes #109504.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140754
Approved by: https://github.com/williamwen42
This commit is contained in:
Ryan Guo
2024-11-18 10:11:27 -08:00
committed by PyTorch MergeBot
parent 175ba9fed6
commit 2da98d9757
2 changed files with 26 additions and 0 deletions

View File

@ -6338,6 +6338,30 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
res = opt_mod(x) res = opt_mod(x)
self.assertEqual(ref, res) self.assertEqual(ref, res)
def test_symnode_is_op(self):
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
def f(x, xs):
if x.size(0) is xs:
return x + 1
else:
return x * 2
t = torch.randn(2)
res = f(t, [1, 2])
self.assertEqual(t * 2, res)
def test_symnode_is_not_op(self):
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
def f(x, xs):
if x.size(0) is not xs:
return x + 1
else:
return x * 2
t = torch.randn(2)
res = f(t, [1, 2])
self.assertEqual(t + 1, res)
instantiate_parametrized_tests(ReproTests) instantiate_parametrized_tests(ReproTests)

View File

@ -69,6 +69,8 @@ supported_tensor_comparison_ops = {
"<=": operator.le, "<=": operator.le,
"==": operator.eq, "==": operator.eq,
"!=": operator.ne, "!=": operator.ne,
"is": operator.is_,
"is not": operator.is_not,
} }
# Ops that allow tensor <op> None # Ops that allow tensor <op> None
supported_const_comparison_ops = { supported_const_comparison_ops = {