mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor torchscript based exporter logic to move them to a single (private) location for better code management. Original public module and method apis are preserved.

- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private&unused

@albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and `torch/autograd/_functions/utils.py`? Thanks!

## BC Breaking
- **Deprecated members in `torch.onnx.verification` are removed**

Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323
Approved by: https://github.com/titaiwangms, https://github.com/angelayi
27 lines
753 B
Python
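# mypy: allow-untyped-defs


def maybe_view(tensor, size, check_same_size=True):
    # Fast path: the tensor already has the requested size, so return it as-is.
    if check_same_size and tensor.size() == size:
        return tensor
    # view() needs a memory layout compatible with the new shape; contiguous() guarantees one.
    return tensor.contiguous().view(size)


def maybe_unexpand(tensor, old_size, check_same_size=True):
    # Fast path: nothing was broadcast, so there is nothing to reduce.
    if check_same_size and tensor.size() == old_size:
        return tensor
    # Broadcasting may prepend leading dims; count how many were added.
    num_unsqueezed = tensor.dim() - len(old_size)
    # Among the remaining (aligned) dims, find those whose size changed,
    # i.e. dims that were expanded from size 1.
    expanded_dims = [
        dim
        for dim, (expanded, original) in enumerate(
            zip(tensor.size()[num_unsqueezed:], old_size)
        )
        if expanded != original
    ]

    # Sum away the prepended leading dims entirely.
    for _ in range(num_unsqueezed):
        tensor = tensor.sum(0, keepdim=False)
    # Sum over the expanded dims, keeping each one as a size-1 dim.
    for dim in expanded_dims:
        tensor = tensor.sum(dim, keepdim=True)
    return tensor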
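A torch-free sketch of the shape bookkeeping in `maybe_unexpand`, assuming the incoming tensor's shape is a broadcast of `old_size` (extra leading dims, plus dims expanded from size 1). The helpers `unexpand_plan` and `reduced_shape` are hypothetical names for illustration only, not part of PyTorch; they replay the function's logic on shapes and do not sum any tensor values.

```python
from typing import List, Sequence, Tuple


def unexpand_plan(expanded_size: Sequence[int], old_size: Sequence[int]) -> Tuple[int, List[int]]:
    """Hypothetical helper mirroring the index logic of maybe_unexpand.

    Returns (number of prepended leading dims, dims expanded from size 1),
    where the expanded-dim indices refer to the shape after the leading
    dims have been removed.
    """
    num_unsqueezed = len(expanded_size) - len(old_size)
    expanded_dims = [
        dim
        for dim, (expanded, original) in enumerate(
            zip(expanded_size[num_unsqueezed:], old_size)
        )
        if expanded != original
    ]
    return num_unsqueezed, expanded_dims


def reduced_shape(expanded_size: Sequence[int], old_size: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: replay maybe_unexpand's sums, tracking only the shape."""
    num_unsqueezed, expanded_dims = unexpand_plan(expanded_size, old_size)
    shape = list(expanded_size)
    for _ in range(num_unsqueezed):
        shape.pop(0)  # like tensor.sum(0, keepdim=False): dim 0 disappears
    for dim in expanded_dims:
        shape[dim] = 1  # like tensor.sum(dim, keepdim=True): dim shrinks to size 1
    return tuple(shape)


# A gradient of shape (4, 2, 3) flowing back to an input broadcast from (1, 3):
print(unexpand_plan((4, 2, 3), (1, 3)))  # (1, [0])
print(reduced_shape((4, 2, 3), (1, 3)))  # (1, 3)
```

In this example one leading dim is summed away, then dim 0 of the remaining `(2, 3)` shape is summed with `keepdim=True`, which recovers the original `(1, 3)` shape.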