pytorch/torch/autograd/_functions/utils.py
Justin Chu 524b78d4f6 [ONNX] Refactor torchscript based exporter (#161323)
Refactor the TorchScript-based exporter logic, moving it to a single (private) location for better code management. The original public module and method APIs are preserved.

- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused

@albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and
`torch/autograd/_functions/utils.py`? Thanks!

## BC Breaking
- **Deprecated members in `torch.onnx.verification` are removed**

Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323
Approved by: https://github.com/titaiwangms, https://github.com/angelayi
2025-09-02 16:10:30 +00:00


# mypy: allow-untyped-defs


def maybe_view(tensor, size, check_same_size=True):
    if check_same_size and tensor.size() == size:
        return tensor
    return tensor.contiguous().view(size)


def maybe_unexpand(tensor, old_size, check_same_size=True):
    if check_same_size and tensor.size() == old_size:
        return tensor
    num_unsqueezed = tensor.dim() - len(old_size)
    expanded_dims = [
        dim
        for dim, (expanded, original) in enumerate(
            zip(tensor.size()[num_unsqueezed:], old_size)
        )
        if expanded != original
    ]

    for _ in range(num_unsqueezed):
        tensor = tensor.sum(0, keepdim=False)
    for dim in expanded_dims:
        tensor = tensor.sum(dim, keepdim=True)
    return tensor
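
The dimension bookkeeping in `maybe_unexpand` can be traced without tensors. The following is a minimal sketch in pure Python that works on shape tuples only; `unexpand_plan` and `reduced_size` are hypothetical helper names introduced for illustration and are not part of `torch`. It assumes, as the function above appears to, that `old_size` can be broadcast to the expanded size.

```python
def unexpand_plan(expanded_size, old_size):
    """Mirror the index computation in maybe_unexpand (hypothetical helper)."""
    # Leading dimensions that broadcasting prepended to the original shape.
    num_unsqueezed = len(expanded_size) - len(old_size)
    # Among the remaining dimensions, those whose extent differs from the original.
    expanded_dims = [
        dim
        for dim, (expanded, original) in enumerate(
            zip(expanded_size[num_unsqueezed:], old_size)
        )
        if expanded != original
    ]
    return num_unsqueezed, expanded_dims


def reduced_size(expanded_size, old_size):
    """Shape that results from applying the two summation loops (hypothetical helper)."""
    size = list(expanded_size)
    num_unsqueezed, expanded_dims = unexpand_plan(expanded_size, old_size)
    for _ in range(num_unsqueezed):
        size.pop(0)  # sum(0, keepdim=False) drops the leading dimension
    for dim in expanded_dims:
        size[dim] = 1  # sum(dim, keepdim=True) keeps the dimension with extent 1
    return tuple(size)


# A (1, 3) tensor broadcast to (4, 2, 3): one prepended dimension, and
# dimension 0 of the remainder was expanded from 1 to 2.
plan = unexpand_plan((4, 2, 3), (1, 3))
final = reduced_size((4, 2, 3), (1, 3))
```

The leading dimensions are summed away first, so the indices in `expanded_dims` already refer to positions in the original shape by the time the second loop runs.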