mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support is
comparison for symnodes (#140754)
Fixes #109504. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140754 Approved by: https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
175ba9fed6
commit
2da98d9757
@ -6338,6 +6338,30 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
res = opt_mod(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_symnode_is_op(self):
|
||||
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
|
||||
def f(x, xs):
|
||||
if x.size(0) is xs:
|
||||
return x + 1
|
||||
else:
|
||||
return x * 2
|
||||
|
||||
t = torch.randn(2)
|
||||
res = f(t, [1, 2])
|
||||
self.assertEqual(t * 2, res)
|
||||
|
||||
def test_symnode_is_not_op(self):
|
||||
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
|
||||
def f(x, xs):
|
||||
if x.size(0) is not xs:
|
||||
return x + 1
|
||||
else:
|
||||
return x * 2
|
||||
|
||||
t = torch.randn(2)
|
||||
res = f(t, [1, 2])
|
||||
self.assertEqual(t + 1, res)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ReproTests)
|
||||
|
||||
|
@ -69,6 +69,8 @@ supported_tensor_comparison_ops = {
|
||||
"<=": operator.le,
|
||||
"==": operator.eq,
|
||||
"!=": operator.ne,
|
||||
"is": operator.is_,
|
||||
"is not": operator.is_not,
|
||||
}
|
||||
# Ops that allow tensor <op> None
|
||||
supported_const_comparison_ops = {
|
||||
|
Reference in New Issue
Block a user