**Background:** Conversion from outer dim -> inner dim makes the (previously valid) assumption that the ragged dim is immediately next to the batch dim. This is no longer the case after #137125.

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual `ragged_idx`. Since `ragged_idx` tells us where the packed ragged / batch dim is, both the ragged and batch outer dims should map to this inner dim. The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass it.
* Fixes outputs across `keepdim` settings when reducing over the ragged / batch dims.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
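The outer dim -> inner dim mapping described above can be sketched in plain Python with no PyTorch dependency. The names `outer_to_inner_dim` and `adjacent_assumption_outer_to_inner_dim` are hypothetical and are not PyTorch's internal API. The sketch assumes that the values tensor drops the outer batch dim and packs batch and ragged together at inner dim `ragged_idx - 1`. For example, an outer shape `(B, D, j1)` with `ragged_idx == 2` is assumed to have values of shape `(D, sum(j))`.

```python
from typing import Sequence, Tuple, Union


def adjacent_assumption_outer_to_inner_dim(dim: int) -> int:
    # Hypothetical: mapping implied by assuming ragged sits at outer dim 1, so
    # outer dims 0 (batch) and 1 (ragged) both collapse to inner dim 0.
    return 0 if dim < 2 else dim - 1


def outer_to_inner_dim(
    ndim: int, dim: Union[int, Sequence[int]], ragged_idx: int
) -> Union[int, Tuple[int, ...]]:
    """Hypothetical helper: map outer (nested) dim(s) to values-tensor dim(s).

    Assumes the values tensor drops the outer batch dim and packs batch and
    ragged together at inner dim ragged_idx - 1.
    """
    if not 1 <= ragged_idx < ndim:
        raise ValueError(f"ragged_idx {ragged_idx} invalid for ndim {ndim}")
    if not isinstance(dim, int):
        # Batch and ragged map to the same inner dim, so drop duplicates while
        # keeping first-seen order (dicts preserve insertion order).
        mapped = (outer_to_inner_dim(ndim, d, ragged_idx) for d in dim)
        return tuple(dict.fromkeys(mapped))
    if dim < 0:
        dim += ndim  # canonicalize negative dims against the outer rank
    if not 0 <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for ndim {ndim}")
    # Batch (0) goes to the packed dim; every other outer dim, including
    # ragged_idx itself, shifts left by one because the batch dim is removed.
    return ragged_idx - 1 if dim == 0 else dim - 1


# Outer shape (B, D, j1) with ragged_idx == 2; values assumed (D, sum(j)).
print(outer_to_inner_dim(3, 0, 2))       # batch -> 1
print(outer_to_inner_dim(3, 2, 2))       # ragged -> 1
print(outer_to_inner_dim(3, (0, 2), 2))  # reduce over batch and ragged -> (1,)
print(adjacent_assumption_outer_to_inner_dim(0))  # -> 0, the D dim
```

With `ragged_idx == 2`, the adjacency assumption sends the batch dim to inner dim 0 (the `D` dim) instead of the packed dim 1, which illustrates the mismatch described above. Deduplicating tuple dims is a design choice of this sketch: reducing over both batch and ragged should touch the packed inner dim only once.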