mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[Dynamo] Fix torch.is_tensor and torch.overrides.is_tensor_like (#88704)
Fixes error from 7k github models: https://github.com/jansel/pytorch-jit-paritybench/blob/master/generated/test_arashwan_matrixnet.py Error: ``` AssertionError: torch.* op returned non-Tensor bool call_function <function is_tensor at 0x7fca94d0faf0> from user code: File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 749, in scatter return scatter_map(inputs) File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 741, in scatter_map assert not torch.is_tensor(obj), 'Tensors not supported in scatter.' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88704 Approved by: https://github.com/jansel
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							3b33a2794e
						
					
				
				
					commit
					911a1349dd
				
			@ -400,6 +400,23 @@ class MiscTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
 | 
			
		||||
        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
 | 
			
		||||
 | 
			
		||||
    def test_is_tensor2(self):
 | 
			
		||||
        def fn(x):
 | 
			
		||||
            if torch.is_tensor(x):
 | 
			
		||||
                return x + 1
 | 
			
		||||
            else:
 | 
			
		||||
                return torch.ones([2, 3])
 | 
			
		||||
 | 
			
		||||
        x1 = {"input": torch.rand(2, 3)}
 | 
			
		||||
        x2 = torch.rand(2, 3)
 | 
			
		||||
        ref1 = fn(x1)
 | 
			
		||||
        ref2 = fn(x2)
 | 
			
		||||
        opt_fn = torch._dynamo.optimize("eager")(fn)
 | 
			
		||||
        res1 = opt_fn(x1)
 | 
			
		||||
        res2 = opt_fn(x2)
 | 
			
		||||
        self.assertEqual(ref1, res1)
 | 
			
		||||
        self.assertEqual(ref2, res2)
 | 
			
		||||
 | 
			
		||||
    def test_numel(self):
 | 
			
		||||
        def fn(a):
 | 
			
		||||
            return a + a.numel() + torch.numel(a)
 | 
			
		||||
@ -1244,6 +1261,32 @@ class MiscTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        self.assertTrue(same(ref0, res0))
 | 
			
		||||
        self.assertTrue(same(ref1, res1))
 | 
			
		||||
 | 
			
		||||
    def test_is_tensor_like2(self):
 | 
			
		||||
        class MyTensor(object):
 | 
			
		||||
            @classmethod
 | 
			
		||||
            def __torch_function__(cls, func, types, args=(), kwargs=None):
 | 
			
		||||
                if kwargs is None:
 | 
			
		||||
                    kwargs = {}
 | 
			
		||||
 | 
			
		||||
                if func is torch.max:
 | 
			
		||||
                    return torch.tensor(123)
 | 
			
		||||
                return func(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        def fn(x):
 | 
			
		||||
            if torch.overrides.is_tensor_like(x):
 | 
			
		||||
                return torch.max(x)
 | 
			
		||||
            else:
 | 
			
		||||
                return torch.zeros(1)
 | 
			
		||||
 | 
			
		||||
        x = MyTensor()
 | 
			
		||||
        ref0 = fn(x)
 | 
			
		||||
        ref1 = fn(4)
 | 
			
		||||
        opt_fn = torch._dynamo.optimize("eager")(fn)
 | 
			
		||||
        res0 = opt_fn(x)
 | 
			
		||||
        res1 = opt_fn(4)
 | 
			
		||||
        self.assertTrue(same(ref0, res0))
 | 
			
		||||
        self.assertTrue(same(ref1, res1))
 | 
			
		||||
 | 
			
		||||
    def test_version_ci(self):
 | 
			
		||||
        # temporary test to check that the ci torch version is set correctly
 | 
			
		||||
        self.assertTrue(hasattr(torch, "_subclasses"))
 | 
			
		||||
 | 
			
		||||
@ -163,8 +163,6 @@ class TorchVariable(VariableTracker):
 | 
			
		||||
            torch.finfo,
 | 
			
		||||
            torch.iinfo,
 | 
			
		||||
            torch.is_floating_point,
 | 
			
		||||
            torch.is_tensor,
 | 
			
		||||
            torch.overrides.is_tensor_like,
 | 
			
		||||
        ):
 | 
			
		||||
            return True
 | 
			
		||||
        return getattr(self.value, "__module__", None) == "math"
 | 
			
		||||
@ -177,9 +175,9 @@ class TorchVariable(VariableTracker):
 | 
			
		||||
            DynamicShapeVariable,
 | 
			
		||||
            GradModeVariable,
 | 
			
		||||
            TensorVariable,
 | 
			
		||||
            UserDefinedObjectVariable,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # print("CALLING ON TORCH", self.value)
 | 
			
		||||
        from .builder import wrap_fx_proxy
 | 
			
		||||
 | 
			
		||||
        constant_args = check_constant_args(args, kwargs)
 | 
			
		||||
@ -206,21 +204,26 @@ class TorchVariable(VariableTracker):
 | 
			
		||||
                return self._call_cross_entropy_loss(tx, args, kwargs, options)
 | 
			
		||||
            else:
 | 
			
		||||
                unimplemented(f"construct nn.Module: {self.value.__name__}")
 | 
			
		||||
        elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
 | 
			
		||||
            assert len(args) == 1
 | 
			
		||||
            if isinstance(args[0], TensorVariable) or (
 | 
			
		||||
                self.value is torch.overrides.is_tensor_like
 | 
			
		||||
                and isinstance(args[0], UserDefinedObjectVariable)
 | 
			
		||||
                and hasattr(args[0].value, "__torch_function__")
 | 
			
		||||
            ):
 | 
			
		||||
                return ConstantVariable(True, **options)
 | 
			
		||||
            else:
 | 
			
		||||
                return ConstantVariable(False, **options)
 | 
			
		||||
        elif (
 | 
			
		||||
            self.value
 | 
			
		||||
            in (
 | 
			
		||||
                torch.is_tensor,
 | 
			
		||||
                torch.is_floating_point,
 | 
			
		||||
                torch.is_complex,
 | 
			
		||||
                torch.overrides.is_tensor_like,
 | 
			
		||||
                torch.is_complex,
 | 
			
		||||
            )
 | 
			
		||||
            and isinstance(args[0], TensorVariable)
 | 
			
		||||
            and args[0].dtype is not None
 | 
			
		||||
        ):
 | 
			
		||||
            if self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
 | 
			
		||||
                return ConstantVariable(True, **options)
 | 
			
		||||
            elif self.value is torch.is_floating_point:
 | 
			
		||||
            if self.value is torch.is_floating_point:
 | 
			
		||||
                return ConstantVariable(args[0].dtype.is_floating_point, **options)
 | 
			
		||||
            elif self.value is torch.is_complex:
 | 
			
		||||
                return ConstantVariable(args[0].dtype.is_complex, **options)
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user