Fixes #137512

Relaxes the restriction that the ragged dim must immediately follow the batch dim, e.g. `(B, *, D_0, ..., D_N)`. This allows constructing NJTs of shape e.g. `(B, D, j0)` directly.

Before this PR, it was already possible to get an NJT of e.g. shape `(B, D, j0)` by constructing an NJT of shape `(B, j0, D)` and transposing it. This PR lets a user go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this.

At the very least, this is useful for testing transposed NJTs. I'm willing to make this functionality private if needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137125
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
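The shape rule described above can be modeled without PyTorch. The sketch below uses a hypothetical helper, `infer_nested_shape`, which is not part of `torch.nested` and is not the PR's implementation. It only illustrates the idea: the one component dim whose size varies across components becomes the ragged dim `j0`, wherever it sits. The flag `ragged_must_follow_batch` is likewise hypothetical and emulates the old restriction. Per this PR, passing components of shape `(D, n_i)` to `torch.nested.nested_tensor(..., layout=torch.jagged)` should similarly yield an NJT of shape `(B, D, j0)`.

```python
from typing import List, Sequence, Tuple, Union

# A result dim is either a concrete size (int) or the symbolic ragged size "j0".
Dim = Union[int, str]


def infer_nested_shape(
    component_shapes: Sequence[Sequence[int]],
    ragged_must_follow_batch: bool = False,
) -> Tuple[Dim, ...]:
    """Hypothetical helper: infer the (B, ...) shape of a jagged nested tensor.

    At most one component dim may differ in size across components; that dim
    becomes the ragged dim and is reported as "j0". If no dim differs, this
    sketch assumes component dim 0 is ragged (the old convention).
    ragged_must_follow_batch=True emulates the old restriction that the ragged
    dim sits immediately after the batch dim.
    """
    if not component_shapes:
        raise ValueError("need at least one component")
    ndim = len(component_shapes[0])
    if ndim == 0 or any(len(s) != ndim for s in component_shapes):
        raise ValueError("components must share the same nonzero number of dims")

    # Component dims whose size is not the same across all components.
    varying = [d for d in range(ndim) if len({s[d] for s in component_shapes}) > 1]
    if len(varying) > 1:
        raise ValueError(f"expected one ragged dim, found {len(varying)}: {varying}")
    ragged = varying[0] if varying else 0

    if ragged_must_follow_batch and ragged != 0:
        raise ValueError("ragged dim must immediately follow the batch dim")

    # Batch dim first, then the component dims with the ragged one made symbolic.
    shape: List[Dim] = [len(component_shapes)]
    for d in range(ndim):
        shape.append("j0" if d == ragged else component_shapes[0][d])
    return tuple(shape)


# (D, n_i) components with D = 4: ragged dim last, i.e. (B, D, j0).
print(infer_nested_shape([(4, 3), (4, 5)]))  # (2, 4, 'j0')
# (n_i, D) components: the previously required (B, j0, D) layout.
print(infer_nested_shape([(3, 4), (5, 4)]))  # (2, 'j0', 4)
```

Under the old rule (`ragged_must_follow_batch=True`), the first call raises `ValueError`, which mirrors why a transpose from `(B, j0, D)` used to be the only route to `(B, D, j0)`.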