[hop][BE] add util diff_meta with prettier error message. (#142162)

The error message changes from:
```python
-torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))]
```
to
```python
+torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={})
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162
Approved by: https://github.com/zou3519
This commit is contained in:
Yidi Wu
2024-12-06 10:38:50 -08:00
committed by PyTorch MergeBot
parent 9ced54a51a
commit 7111cd6ee0
3 changed files with 74 additions and 18 deletions

View File

@ -6957,6 +6957,23 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(AssertionError):
opt_test(True, False, inp)
def test_cond_with_mismatched_output(self):
def output_mismatch_test(x):
def true_fn():
return torch.concat([x, x])
def false_fn():
return x.sin()
return torch.cond(x.sum() > 0, true_fn, false_fn)
x = torch.randn(2, 3)
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
output_mismatch_test(x)
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
torch.compile(output_mismatch_test)(x)
def test_non_aliasing_util(self):
from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing