[Dynamo] Fix torch.is_tensor and torch.overrides.is_tensor_like (#88704)

Fixes an error from the 7k GitHub models suite: https://github.com/jansel/pytorch-jit-paritybench/blob/master/generated/test_arashwan_matrixnet.py

Error:
```
AssertionError: torch.* op returned non-Tensor bool call_function <function is_tensor at 0x7fca94d0faf0>

from user code:
   File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 749, in scatter
      return scatter_map(inputs)
   File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 741, in scatter_map
      assert not torch.is_tensor(obj), 'Tensors not supported in scatter.'
```
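For context, a minimal sketch of the failing pattern, distilled from the paritybench model above (names and shapes are illustrative): `torch.is_tensor` returns a plain Python `bool`, and before this fix Dynamo traced the call through its generic `torch.*` proxy path whenever the argument was not a compile-time constant, tripping the assertion above.

```
import torch
import torch._dynamo

def scatter_map(obj):
    # torch.is_tensor returns a plain bool; tracing it as a torch.* op
    # produced a non-Tensor result and hit the assertion above.
    assert not torch.is_tensor(obj), "Tensors not supported in scatter."
    return [v * 2 for v in obj.values()]

opt_fn = torch._dynamo.optimize("eager")(scatter_map)
opt_fn({"input": torch.rand(2, 3)})  # raised AssertionError before this fix
```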

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88704
Approved by: https://github.com/jansel
Author: Yanbo Liang
Date: 2022-11-14 22:45:50 +00:00
Committed by: PyTorch MergeBot
Parent: 3b33a2794e
Commit: 911a1349dd

2 changed files with 55 additions and 9 deletions

test/dynamo/test_misc.py:

```
@@ -400,6 +400,23 @@ class MiscTests(torch._dynamo.test_case.TestCase):
         return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
 
+    def test_is_tensor2(self):
+        def fn(x):
+            if torch.is_tensor(x):
+                return x + 1
+            else:
+                return torch.ones([2, 3])
+
+        x1 = {"input": torch.rand(2, 3)}
+        x2 = torch.rand(2, 3)
+        ref1 = fn(x1)
+        ref2 = fn(x2)
+        opt_fn = torch._dynamo.optimize("eager")(fn)
+        res1 = opt_fn(x1)
+        res2 = opt_fn(x2)
+        self.assertEqual(ref1, res1)
+        self.assertEqual(ref2, res2)
+
     def test_numel(self):
         def fn(a):
             return a + a.numel() + torch.numel(a)
```
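For reference, the eager semantics this test pins down: `torch.is_tensor` is true only for actual `torch.Tensor` instances, so the compiled function must take the `x + 1` branch for the tensor input and the `torch.ones` branch for the dict input.

```
>>> import torch
>>> torch.is_tensor(torch.rand(2, 3))
True
>>> torch.is_tensor({"input": torch.rand(2, 3)})
False
```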
```
@@ -1244,6 +1261,32 @@ class MiscTests(torch._dynamo.test_case.TestCase):
         self.assertTrue(same(ref0, res0))
         self.assertTrue(same(ref1, res1))
 
+    def test_is_tensor_like2(self):
+        class MyTensor(object):
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+
+                if func is torch.max:
+                    return torch.tensor(123)
+                return func(*args, **kwargs)
+
+        def fn(x):
+            if torch.overrides.is_tensor_like(x):
+                return torch.max(x)
+            else:
+                return torch.zeros(1)
+
+        x = MyTensor()
+        ref0 = fn(x)
+        ref1 = fn(4)
+        opt_fn = torch._dynamo.optimize("eager")(fn)
+        res0 = opt_fn(x)
+        res1 = opt_fn(4)
+        self.assertTrue(same(ref0, res0))
+        self.assertTrue(same(ref1, res1))
+
     def test_version_ci(self):
         # temporary test to check that the ci torch version is set correctly
         self.assertTrue(hasattr(torch, "_subclasses"))
```
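For reference, `torch.overrides.is_tensor_like(x)` is essentially `type(x) is torch.Tensor or hasattr(type(x), "__torch_function__")`, so both a real tensor and a `MyTensor`-style class count as tensor-like while a plain `int` does not:

```
import torch
from torch.overrides import is_tensor_like

class MyTensor(object):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

print(is_tensor_like(torch.rand(2)))  # True: a real torch.Tensor
print(is_tensor_like(MyTensor()))     # True: type defines __torch_function__
print(is_tensor_like(4))              # False: a plain Python int
```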

torch/_dynamo/variables/torch.py:

```
@@ -163,8 +163,6 @@ class TorchVariable(VariableTracker):
             torch.finfo,
             torch.iinfo,
             torch.is_floating_point,
-            torch.is_tensor,
-            torch.overrides.is_tensor_like,
         ):
             return True
         return getattr(self.value, "__module__", None) == "math"
```
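Why these two entries are removed rather than kept: this list feeds Dynamo's constant-fold-through check, which only fires when every argument is a compile-time constant (see the `check_constant_args` call below). A traced `TensorVariable` is never a constant, so `torch.is_tensor(x)` skipped this path and fell through to the generic proxy path, returning a raw `bool`; the dedicated branch added in the last hunk below handles these calls by variable type instead.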
```
@@ -177,9 +175,9 @@ class TorchVariable(VariableTracker):
             DynamicShapeVariable,
             GradModeVariable,
             TensorVariable,
+            UserDefinedObjectVariable,
         )
 
-        # print("CALLING ON TORCH", self.value)
         from .builder import wrap_fx_proxy
 
         constant_args = check_constant_args(args, kwargs)
```
```
@@ -206,21 +204,26 @@ class TorchVariable(VariableTracker):
                 return self._call_cross_entropy_loss(tx, args, kwargs, options)
             else:
                 unimplemented(f"construct nn.Module: {self.value.__name__}")
+        elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
+            assert len(args) == 1
+            if isinstance(args[0], TensorVariable) or (
+                self.value is torch.overrides.is_tensor_like
+                and isinstance(args[0], UserDefinedObjectVariable)
+                and hasattr(args[0].value, "__torch_function__")
+            ):
+                return ConstantVariable(True, **options)
+            else:
+                return ConstantVariable(False, **options)
         elif (
             self.value
             in (
-                torch.is_tensor,
                 torch.is_floating_point,
                 torch.is_complex,
-                torch.overrides.is_tensor_like,
-                torch.is_complex,
             )
             and isinstance(args[0], TensorVariable)
             and args[0].dtype is not None
         ):
-            if self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
-                return ConstantVariable(True, **options)
-            elif self.value is torch.is_floating_point:
+            if self.value is torch.is_floating_point:
                 return ConstantVariable(args[0].dtype.is_floating_point, **options)
             elif self.value is torch.is_complex:
                 return ConstantVariable(args[0].dtype.is_complex, **options)
```
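As a standalone sketch of the new behavior (simplified: the stand-in classes below are hypothetical, while the real code dispatches on Dynamo's `TensorVariable`/`UserDefinedObjectVariable` trackers): the answer is computed at trace time from the kind of variable being tracked and baked into the graph as a `ConstantVariable`, so no `bool`-returning `torch.*` call is ever emitted.

```
import torch

# Hypothetical stand-ins for Dynamo's variable trackers, for illustration only.
class FakeTensorVariable:
    pass

class FakeUserDefinedObjectVariable:
    def __init__(self, value):
        self.value = value  # the underlying user object

def fold_is_tensor_like(fn, var):
    # Mirrors the new elif branch: a traced tensor always folds to True; a
    # user object counts only for is_tensor_like and only if it carries
    # __torch_function__; everything else folds to False.
    if isinstance(var, FakeTensorVariable):
        return True
    return (
        fn is torch.overrides.is_tensor_like
        and isinstance(var, FakeUserDefinedObjectVariable)
        and hasattr(var.value, "__torch_function__")
    )

print(fold_is_tensor_like(torch.is_tensor, FakeTensorVariable()))              # True
print(fold_is_tensor_like(torch.is_tensor, FakeUserDefinedObjectVariable(4)))  # False
```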