Proper view support for jagged layout NestedTensor (#113279)

This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
This commit is contained in:
Joel Schlosser
2024-03-21 18:03:48 -04:00
committed by PyTorch MergeBot
parent bde22835c6
commit cd6bfc7965
17 changed files with 542 additions and 205 deletions

View File

@ -81,6 +81,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"_chunk_grad_outputs_efficient_attention", # returns a bool
"_fused_sdp_choice", # returns an int
"_print", # no return
"_nested_get_ragged_idx", # returns an int
]
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [