mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[hop][BE] add util diff_meta with prettier error message. (#142162)
The error message changes from: ```python -torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))] ``` to ```python +torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
9ced54a51a
commit
7111cd6ee0
@ -9,6 +9,7 @@ import torch.fx.traceback as fx_traceback
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._ops import OperatorBase
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.passes.shape_prop import TensorMetadata
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
|
||||
@ -481,6 +482,27 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
||||
return torch.select_copy(t, dim, 0)
|
||||
|
||||
|
||||
# Reports the difference between meta of two tensors in a string
|
||||
def diff_tensor_meta(
|
||||
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
|
||||
) -> List[str]:
|
||||
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
|
||||
|
||||
pair_diffs = []
|
||||
for meta_name in TensorMetadata._fields:
|
||||
if not check_grad and meta_name == "requires_grad":
|
||||
continue
|
||||
val1 = getattr(meta1, meta_name)
|
||||
val2 = getattr(meta2, meta_name)
|
||||
try:
|
||||
if val1 != val2:
|
||||
pair_diffs.append(f"'{meta_name}'")
|
||||
except GuardOnDataDependentSymNode as _:
|
||||
pair_diffs.append(f"'{meta_name}'")
|
||||
continue
|
||||
return pair_diffs
|
||||
|
||||
|
||||
# Note [lifted arg types in hop]
|
||||
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
|
||||
# This has implications for the types of lifted args for different dispatch keys:
|
||||
|
Reference in New Issue
Block a user