Proper view support for jagged layout NestedTensor (#113279)

This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
This commit is contained in:
Joel Schlosser
2024-03-21 18:03:48 -04:00
committed by PyTorch MergeBot
parent bde22835c6
commit cd6bfc7965
17 changed files with 542 additions and 205 deletions

View File

@ -797,9 +797,10 @@ def signature_from_schema(
or name.startswith("new_")
or name.endswith("_like")
)
is_dummy_function = category_override == "dummy"
tensor_options_args: List[PythonArgument] = []
if is_factory_function or is_like_or_new_function:
if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
def topt_default_init(name: str) -> Optional[str]:
topt_args = func.arguments.tensor_options