mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[hop][BE] add util diff_meta with prettier error message. (#142162)
The error message changes from: ```python -torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))] ``` to ```python +torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
9ced54a51a
commit
7111cd6ee0
@ -6957,6 +6957,23 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
opt_test(True, False, inp)
|
||||
|
||||
def test_cond_with_mismatched_output(self):
|
||||
def output_mismatch_test(x):
|
||||
def true_fn():
|
||||
return torch.concat([x, x])
|
||||
|
||||
def false_fn():
|
||||
return x.sin()
|
||||
|
||||
return torch.cond(x.sum() > 0, true_fn, false_fn)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
||||
output_mismatch_test(x)
|
||||
|
||||
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
||||
torch.compile(output_mismatch_test)(x)
|
||||
|
||||
def test_non_aliasing_util(self):
|
||||
from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing
|
||||
|
||||
|
Reference in New Issue
Block a user