mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Proper view support for jagged layout NestedTensor (#113279)
This PR: * Introduces an ATen op for creating true jagged views from a dense values buffer * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)` * This ops is implemented on the Python side using torch.library so we can return a subclass instance * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer` * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()` * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view * Introduces an ATen op for accessing the `values` component of an NT via a view * `_nested_get_values(nt)` * **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively. * Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly * Similarly, avoid `buffer_from_jagged()`, preferring `values()` * Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack) With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling. Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922) Co-authored-by: voznesenskym <voznesenskym@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
bde22835c6
commit
cd6bfc7965
@ -797,9 +797,10 @@ def signature_from_schema(
|
||||
or name.startswith("new_")
|
||||
or name.endswith("_like")
|
||||
)
|
||||
is_dummy_function = category_override == "dummy"
|
||||
|
||||
tensor_options_args: List[PythonArgument] = []
|
||||
if is_factory_function or is_like_or_new_function:
|
||||
if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
|
||||
|
||||
def topt_default_init(name: str) -> Optional[str]:
|
||||
topt_args = func.arguments.tensor_options
|
||||
|
Reference in New Issue
Block a user