[hop][BE] add util diff_meta with prettier error message. (#142162)

The error message changes from:
```python
-torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))]
```
to
```python
+torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={})
```
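
For context, this message comes from dynamo's metadata check on higher-order-op branches. A minimal sketch of a repro that can trigger it (hypothetical example; `torch.cond` branches returning tensors of different shapes):
```python
import torch

def true_fn(x):
    return x.repeat(2, 1)  # returns shape [4, 3]

def false_fn(x):
    return x.clone()  # returns shape [2, 3]

@torch.compile(fullgraph=True)
def f(pred, x):
    # Branches of torch.cond must return tensors with matching metadata;
    # a shape mismatch surfaces as the Unsupported error quoted above.
    return torch.cond(pred, true_fn, false_fn, (x,))

# f(torch.tensor(True), torch.randn(2, 3))  # raises torch._dynamo.exc.Unsupported
```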

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162
Approved by: https://github.com/zou3519
Yidi Wu
2024-12-06 10:38:50 -08:00
committed by PyTorch MergeBot
parent 9ced54a51a
commit 7111cd6ee0
3 changed files with 74 additions and 18 deletions


```python
@@ -9,6 +9,7 @@ import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 from torch._ops import OperatorBase
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.passes.shape_prop import TensorMetadata
 from torch.multiprocessing.reductions import StorageWeakRef
```
```python
@@ -481,6 +482,27 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
     return torch.select_copy(t, dim, 0)
 
 
+# Reports the difference between meta of two tensors in a string
+def diff_tensor_meta(
+    meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
+) -> List[str]:
+    from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
+
+    pair_diffs = []
+    for meta_name in TensorMetadata._fields:
+        if not check_grad and meta_name == "requires_grad":
+            continue
+        val1 = getattr(meta1, meta_name)
+        val2 = getattr(meta2, meta_name)
+        try:
+            if val1 != val2:
+                pair_diffs.append(f"'{meta_name}'")
+        except GuardOnDataDependentSymNode as _:
+            pair_diffs.append(f"'{meta_name}'")
+            continue
+    return pair_diffs
 
 
 # Note [lifted arg types in hop]
 # For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
 # This has implications for the types of lifted args for different dispatch keys:
```
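
Note the `except GuardOnDataDependentSymNode` branch: if comparing a field would require guarding on a data-dependent symbolic value, the helper conservatively reports that field as differing rather than raising. A quick sketch of how the new helper behaves on the metadata from the example above (constructing `TensorMetadata` by hand; field values taken from the error message, import path per this diff):
```python
import torch
from torch.fx.passes.shape_prop import TensorMetadata
from torch._higher_order_ops.utils import diff_tensor_meta

meta1 = TensorMetadata(
    shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False,
    stride=(3, 1), memory_format=None, is_quantized=False, qparams={},
)
# TensorMetadata is a NamedTuple, so _replace gives a copy with a new shape.
meta2 = meta1._replace(shape=torch.Size([2, 3]))

print(diff_tensor_meta(meta1, meta2))  # ["'shape'"]
```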