pytorch/torch/nested
Joel Schlosser e803a3d83a Fix reductions for NJTs with ragged_idx != 1 (#142173)
**Background:** The outer dim -> inner dim conversion assumes that the ragged dim is immediately next to the batch dim. That assumption was valid before #137125 but no longer holds.

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual ragged_idx. Since ragged_idx tells us where the packed ragged / batch dim is, both ragged and batch outer dims should map to this inner dim. The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass this.
* Fixes reduction outputs for both `keepdim=True` and `keepdim=False` when reducing over the ragged / batch dims.
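
The dim mapping described in the first bullet can be sketched roughly as follows. This is an illustrative, pure-Python sketch and not the code from this PR: the function name `outer_to_inner_dim_sketch` and its signature are hypothetical, and it assumes the inner (packed values) tensor has one fewer dim than the outer NJT, with the packed ragged / batch dim sitting at `ragged_idx - 1`.

```python
def outer_to_inner_dim_sketch(ndim: int, dim: int, ragged_idx: int) -> int:
    """Map an outer NJT dim to a dim of the packed inner values tensor.

    Hypothetical helper for illustration only.
    Assumptions: the outer tensor has `ndim` dims with the batch dim at 0,
    and the inner tensor drops the batch dim, so the packed ragged / batch
    dim lives at inner index `ragged_idx - 1`.
    """
    if not 1 <= ragged_idx < ndim:
        raise ValueError(f"ragged_idx must be in [1, {ndim}), got {ragged_idx}")
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for ndim={ndim}")

    # Wrap negative dims into [0, ndim).
    if dim < 0:
        dim += ndim

    # The batch dim has no inner dim of its own; it is packed together
    # with the ragged dim, so it maps to the packed inner dim.
    if dim == 0:
        return ragged_idx - 1

    # Every other outer dim shifts down by one. For dim == ragged_idx this
    # also yields ragged_idx - 1, so batch and ragged dims coincide.
    return dim - 1


if __name__ == "__main__":
    # Outer shape (B, D, ragged, E) with ragged_idx == 2.
    print([outer_to_inner_dim_sketch(4, d, 2) for d in range(4)])  # [1, 0, 1, 2]
```

With `ragged_idx == 1` the sketch reduces to the earlier behavior (batch and ragged dims both map to inner dim 0), which is why the old assumption went unnoticed until the ragged dim could move.
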
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
2024-12-06 01:23:17 +00:00