mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support is
comparison for symnodes (#140754)
Fixes #109504. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140754 Approved by: https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
175ba9fed6
commit
2da98d9757
@ -6338,6 +6338,30 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||||||
res = opt_mod(x)
|
res = opt_mod(x)
|
||||||
self.assertEqual(ref, res)
|
self.assertEqual(ref, res)
|
||||||
|
|
||||||
|
def test_symnode_is_op(self):
|
||||||
|
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
|
||||||
|
def f(x, xs):
|
||||||
|
if x.size(0) is xs:
|
||||||
|
return x + 1
|
||||||
|
else:
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
t = torch.randn(2)
|
||||||
|
res = f(t, [1, 2])
|
||||||
|
self.assertEqual(t * 2, res)
|
||||||
|
|
||||||
|
def test_symnode_is_not_op(self):
|
||||||
|
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
|
||||||
|
def f(x, xs):
|
||||||
|
if x.size(0) is not xs:
|
||||||
|
return x + 1
|
||||||
|
else:
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
t = torch.randn(2)
|
||||||
|
res = f(t, [1, 2])
|
||||||
|
self.assertEqual(t + 1, res)
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(ReproTests)
|
instantiate_parametrized_tests(ReproTests)
|
||||||
|
|
||||||
|
@ -69,6 +69,8 @@ supported_tensor_comparison_ops = {
|
|||||||
"<=": operator.le,
|
"<=": operator.le,
|
||||||
"==": operator.eq,
|
"==": operator.eq,
|
||||||
"!=": operator.ne,
|
"!=": operator.ne,
|
||||||
|
"is": operator.is_,
|
||||||
|
"is not": operator.is_not,
|
||||||
}
|
}
|
||||||
# Ops that allow tensor <op> None
|
# Ops that allow tensor <op> None
|
||||||
supported_const_comparison_ops = {
|
supported_const_comparison_ops = {
|
||||||
|
Reference in New Issue
Block a user