Commit Graph

53 Commits

Author SHA1 Message Date
5dc4f652bc Backward support for unbind() with NJT (#128032)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
2024-06-18 20:29:00 +00:00
2a41fc0390 Short-term fix to preserve NJT metadata cache in torch.compile (#122836)
Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122836
Approved by: https://github.com/soulitzer
ghstack dependencies: #127007, #128057
2024-06-17 15:25:09 +00:00
67e6c76a18 Support apply_(callable) sugar for CPU NJTs (#125416)
Example:
```python
nt.apply_(lambda x: x * 2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125416
Approved by: https://github.com/soulitzer
2024-06-12 20:30:57 +00:00
ec1fdda196 Fix jagged NT softmax semantics (#119459)
Before: `softmax` definition uses `jagged_unary_pointwise()` (wrong)
After: `softmax` impl adjusts the `dim` arg to account for the difference in dimensionality between the outer NT and the NT's `_values`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119459
Approved by: https://github.com/soulitzer
2024-06-12 19:12:03 +00:00
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
c1a43a69e4 [NestedTensor] Add error checks for unbind operator coverage when ragged_idx != 1 (#128058)
Summary:
Add the following error checks for the `unbind` operator on `NestedTensor`s when `ragged_idx != 1`:

- The current implementation allows the creation of `NestedTensor` instances from the class definition with an `offsets` tensor that applies to a dimension other than the jagged dimension. This diff ensures that `unbind` fails when the `offsets` exceed the length of the jagged dimension.

Test Plan:
Added the following unit tests:

`test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu` verifies that `unbind` fails when there is a mismatch between the offsets and the jagged dimension, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

Reviewed By: davidberard98

Differential Revision: D57989082

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128058
Approved by: https://github.com/davidberard98
2024-06-06 01:56:12 +00:00
7c3740d388 [NestedTensor] Extend coverage for unbind when ragged_idx != 1 (#127493)
Summary:
Extend coverage for the `NestedTensor` `unbind` operator to cases in which `ragged_idx != 1`.

Currently, the `unbind` operator in the `NestedTensor` class splits a tensor along the 0-th dimension, where the `ragged_idx` property, which controls the jagged dimension upon which `unbind` splits, is 1. This diff extends support for `ragged_idx != 1` in `NestedTensor`s, allowing `unbind` to split a tensor along a jagged dimension greater than 0 for `NestedTensor`s with and without the `lengths` property.

Test Plan:
Added the following unit tests:

`test_unbind_ragged_idx_equals_2_cpu`, `test_unbind_ragged_idx_equals_3_cpu`, and `test_unbind_ragged_idx_equals_last_dim_cpu` verify that `unbind` works for all jagged dimensions greater than 1, for `NestedTensor`s without `lengths`.
```
test_unbind_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_ragged_idx_equals_last_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_cpu` and `test_unbind_with_lengths_ragged_idx_equals_1_cpu` verify that `unbind` works when the jagged dimension is 1, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_with_lengths_ragged_idx_equals_1_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_2_cpu` and `test_unbind_with_lengths_ragged_idx_equals_3_cpu` verify that `unbind` works when the jagged dimension is greater than 1, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_with_lengths_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_0_cpu` verifies that `unbind` fails when the jagged dimension is 0 (the batch dimension), for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_0_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu` verifies that `unbind` fails when there is a mismatch between the offsets and the jagged dimension, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_wrong_lengths_cpu` verifies that `unbind` fails when the lengths exceed the limitations set by offsets, for `NestedTensor`s with `lengths`.

```
test_unbind_with_wrong_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

Differential Revision: D57942686

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127493
Approved by: https://github.com/davidberard98
2024-06-03 17:46:12 +00:00
82edc8b5d5 [NT] Make NestedTensor register as having symbolic sizes/strides (#124687)
Fixes #123698

This PR makes TensorImpl::has_symbolic_sizes_strides return false for NestedTensors.

1. It passes in the actual sizes when we call `_make_wrapper_subclass` - this is the change that makes the subclass register as `has_symbolic_sizes_strides() == True`
2. It adds a field to `_make_wrapper_subclass` where an explicit `numel` can be provided. This allows us to skip the numel computation for the storage, which previously fails due to arithmetic on NestedInts.
3. Implements `aten::numel` for NJT - this is separate from the overridden numel in `make_wrapper_subclass` for now. Note also that this means that we leave `dispatch_sizes_strides_policy="sizes"`, so that we call into the custom `numel` implementation (as well as `sizes` and `strides`), because `numel` cannot currently be computed from `sizes` for NJT.

Note also that this depends on #121361, because calling TensorImpl::set_sizes_and_strides() tries to clone the sizes into the tensor, which means that we need `clone` to be implemented on NestedInt.

Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124687
Approved by: https://github.com/albanD
2024-05-13 16:50:25 +00:00
7ffa5558ee Revert "[FX] Update type hints in torch.fx._compatibility.py (#125469)"
This reverts commit 235b4d6ec22ddac35b2e47b7e871ef10538d4aee.

Reverted https://github.com/pytorch/pytorch/pull/125469 on behalf of https://github.com/izaitsevfb due to breaks pyre in dependent projects (internal: see D56986361) ([comment](https://github.com/pytorch/pytorch/pull/125469#issuecomment-2096665396))
2024-05-06 18:36:43 +00:00
235b4d6ec2 [FX] Update type hints in torch.fx._compatibility.py (#125469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125469
Approved by: https://github.com/Skylion007
ghstack dependencies: #125468
2024-05-05 19:30:22 +00:00
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
638b003cb7 [NJT] .to() properly updates device of offsets (#122797)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122797
Approved by: https://github.com/jbschlosser
2024-04-02 16:07:27 +00:00
4290a57e9c Revert "[NJT] .to() properly updates device of offsets (#122797)"
This reverts commit 3e7fd45b409966440c54f5e370885b4b2a388a01.

Reverted https://github.com/pytorch/pytorch/pull/122797 on behalf of https://github.com/jeffdaily due to Sorry for reverting your change but it is failing CUDA and ROCm jobs in trunk. Please help take a look and reland the change ([comment](https://github.com/pytorch/pytorch/pull/122797#issuecomment-2025473181))
2024-03-28 15:17:45 +00:00
3e7fd45b40 [NJT] .to() properly updates device of offsets (#122797)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122797
Approved by: https://github.com/jbschlosser
2024-03-28 00:56:23 +00:00
6767c04fde Forward fix for broken internal tests related to NJT view dummy (#122704)
(internal link) [example test breakage](https://www.internalfb.com/intern/test/562950061753019?ref_report_id=0)

Symptom: `type stub not overridden` for SymInt. The global NJT dummy relies on `SymInt.__mul__()` in its constructor. Lazily constructing the dummy avoids the race.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122704
Approved by: https://github.com/soulitzer
2024-03-26 21:22:12 +00:00
cd6bfc7965 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-22 02:12:36 +00:00
224beecee6 Revert "Proper view support for jagged layout NestedTensor (#113279)"
This reverts commit 5855c490f09a028bfdfefea8b93c9833eb55dc5c.

Reverted https://github.com/pytorch/pytorch/pull/113279 on behalf of https://github.com/jbschlosser due to Need to fix BC thing ([comment](https://github.com/pytorch/pytorch/pull/113279#issuecomment-2013899762))
2024-03-21 22:03:01 +00:00
5855c490f0 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-20 23:45:34 +00:00
0e604becc5 [NJT] support chunk on batch dim (#119713)
- support chunk op on batch dim
- support empty_like op
- add tests for the like ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119713
Approved by: https://github.com/jbschlosser
2024-03-05 17:57:50 +00:00
8994f2367d Revert "Fix jagged NT softmax semantics (#119459)"
This reverts commit 6adadbaf7943f760ea2375619b1783020b69d4e6.

Reverted https://github.com/pytorch/pytorch/pull/119459 on behalf of https://github.com/malfet due to broke dynamo, see https://github.com/pytorch/pytorch/actions/runs/7835402753/job/21386634602 ([comment](https://github.com/pytorch/pytorch/pull/119459#issuecomment-1935246413))
2024-02-09 02:31:49 +00:00
6adadbaf79 Fix jagged NT softmax semantics (#119459)
Before: `softmax` definition uses `jagged_unary_pointwise()` (wrong)
After: `softmax` impl adjusts the `dim` arg to account for the difference in dimensionality between the outer NT and the NT's `_values`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119459
Approved by: https://github.com/soulitzer
2024-02-08 20:13:12 +00:00
278a0e1600 [NestedTensor] Support binary pointwise ops with >2 inputs (if inputs are non-tensors) (#119419)
It should usually be safe to run pointwise binary ops with >2 inputs. e.g. threshold_backward(tensor, tensor, scalar): we just operate on the values of the nested tensors, and pass in the other args as-is.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119419
Approved by: https://github.com/soulitzer
2024-02-08 20:06:40 +00:00
460950d3aa [Nested Tensor] Support ragged_idx != 1 on aten::is_same_size, aten::_to_copy (#118442)
is_same_size is needed internally; `_to_copy` should be easy because it doesn't support new layouts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118442
Approved by: https://github.com/cpuhrsch
2024-01-30 01:32:51 +00:00
2842d3c9d3 [Nested Tensor] view: basic support for ragged_idx != 1 and _unsafe_view (#118317)
Uses case: `_unsafe_view` is used in aot_autograd to create a view that doesn't register as a view:

eebe7e1d37/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py (L470-L476)

If a transposed nested tensor (i.e. NT with ragged_idx != 1) encounters this code path, it previously would fail for two reasons: 1) because `_unsafe_view` isn't registered, and 2) because ragged_idx != 1 is not supported. This PR adds support for `_unsafe_view` (completely reusing the implementation of `view`; this just registers `_unsafe_view` as another op using the same implementation). It also adds support for ragged_idx != 1, but only for trivial cases where inp._size == size (the use case used by aot_autograd).

Tests: verify that the result of `_unsafe_view` doesn't have a `_base`, and that simple views on transposed NTs work.

Differential Revision: [D53096814](https://our.internmc.facebook.com/intern/diff/D53096814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118317
Approved by: https://github.com/soulitzer
2024-01-26 17:29:37 +00:00
52c5803088 [NestedTensor] Support ragged_idx != 1 in pointwise ops (#118157)
This PR allows pointwise ops to operate on tensors with ragged_idx != 1. It does this by passing the ragged_idx metadata into the construction of the returned NestedTensor when computing pointwise ops. The assumption is that: pointwise ops can operate directly on the values tensors, and the resulting tensor should have all the same metadata properties as the input tensors. For binary ops, a test is added to verify that adding two tensors with different ragged_idx cannot be added.

Previously:
* unary pointwise ops would error out when performed on nested tensors with ragged_idx != 1
* binary pointwise ops would produce tensors with nonsense shapes

Differential Revision: [D53032641](https://our.internmc.facebook.com/intern/diff/D53032641)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118157
Approved by: https://github.com/jbschlosser
2024-01-25 23:34:15 +00:00
f70aeb4ffd Fix backward for reshape() on jagged layout NT (#117137)
Provides symbolic C++-side `reshape_as()` / `reshape()` decomps for jagged layout NTs to make the backwards pass work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117137
Approved by: https://github.com/soulitzer
2024-01-10 23:35:07 +00:00
0b0c76bace Support squeeze.dim for jagged NT (#116891)
As title. Needed for `rev_view_func()` of `unsqueeze()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116891
Approved by: https://github.com/soulitzer
ghstack dependencies: #115894, #116512
2024-01-06 01:00:53 +00:00
ea3a5f8ddc Add chunk for jagged layout NT (#115842)
Nice to have for the [SDPA tutorial](https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115842
Approved by: https://github.com/soulitzer
ghstack dependencies: #115192, #116111
2023-12-20 20:13:20 +00:00
1474eb5f29 Fix jagged composite impl of flatten() (#115192)
Need to handle this in `NestedTensor.__torch_function__()` since it's CompositeImplicit
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115192
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
2023-12-19 19:15:21 +00:00
bf62511e07 Reshape decomposition for jagged layout NT (#115191)
No more segfault from using `reshape()` on jagged NT :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115191
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
2023-12-18 22:34:41 +00:00
6fee208064 Handle -1 in jagged layout NT view ops (#115843)
Allows for inheriting the ragged and batch dims via -1:
```python
nt.view(-1, -1, D)
nt.expand(B, -1, D)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115843
Approved by: https://github.com/soulitzer
ghstack dependencies: #115636
2023-12-15 00:42:47 +00:00
41b1919208 [nested_tensor]Python subclass NT overhead improvement (2/n): avoid getting from WeakTensorKeyDictionary twice during __init__ (#115450)
Summary:
Most NT operations end with creating a new NestedTensor, which is time-consuming. Trying to reduce overhead during the NestedTensor creation.

The ops return a new NestedTensor with the same offsets, so "tensor not in _tensor_symint_registry" would be false in most case. The "in" (__contain__) function takes ~8 us. If we use the "get" directly, then we save a few us for most NT operations.

Test Plan:
Before:
get_tensor_symint take 15us
https://pxl.cl/3XF83
After
get_tensor_symint take 10us
https://pxl.cl/3XFc9

Differential Revision: D51992836

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115450
Approved by: https://github.com/soulitzer
2023-12-09 03:12:31 +00:00
e071d6a9eb [Nested tensor]avoid using shape in python subclass NT, use _size instead (#115371)
Summary:
calling tensor.shape will call torch_dispatch which adds more overhead.

Testing overhead difference in "NT + NT" operation:
**Before:**
the add operation takes ~300us
{F1167963824}
**After:**
the add operation takes ~200us
 {F1167964056}

Test Plan: unit tests in test_nestedtensor

Differential Revision: D51949135

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115371
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
2023-12-08 02:08:36 +00:00
3b01f30b20 Prevent invalid pointwise ops on jagged with transposed ragged dim (#115190)
TODO: tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115190
Approved by: https://github.com/soulitzer, https://github.com/ani300
2023-12-08 00:54:03 +00:00
c99db5617a Introduce general metadata cache to jagged layout NestedTensor (#115212)
Slight refactor to:
* lazily compute min / max seq_len used for flash. this avoids unnecessary graph breaks / specialization when we're not accessing these
* store min / max seq_len in a general `metadata_cache`. condensing these should make it easier to avoid specializing on these and others we may add in the future
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115212
Approved by: https://github.com/soulitzer, https://github.com/ani300
ghstack dependencies: #114311
2023-12-06 19:40:35 +00:00
1dc4588c6a Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114164
Approved by: https://github.com/jbschlosser
2023-12-05 06:33:45 +00:00
5cfda9b7f8 Revert "Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)"
This reverts commit aafa8233a4a1f336014cb122d16941e5b593706c.

Reverted https://github.com/pytorch/pytorch/pull/114164 on behalf of https://github.com/malfet due to Broke ROCM, see aafa8233a4 ([comment](https://github.com/pytorch/pytorch/pull/114164#issuecomment-1839798986))
2023-12-05 00:35:20 +00:00
aafa8233a4 Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114164
Approved by: https://github.com/jbschlosser
2023-12-04 21:54:02 +00:00
2a8a7425be Fix to wrap jagged dims for split() / split_with_sizes() (#113591)
Still need OpInfo-style tests to catch things like this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113591
Approved by: https://github.com/soulitzer
2023-11-14 19:36:08 +00:00
ea39cc34f9 Refactor NestedTensor subclass to remove ragged_size from constructor (#113491)
This PR removes the need for passing `ragged_size` into the `NestedTensor` constructor. This was an artifact of fake-ification, where sometimes we needed the NT to have a symbolic singleton symint shape for the ragged dimension. The new way of achieving this is to also store mappings between fake / functional tensors -> symbolic symints in the ragged structure registry. Now the `NestedTensor` constructor can just query this registry for the `ragged_size`.

Old: `NestedTensor(values, offsets, *, ragged_size=None, **kwargs)`
New: `NestedTensor(values, offsets, **kwargs)`

This makes it possible to have a `_nested_view_from_values_offsets(values, offsets)` without needing to pass a `ragged_size`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113491
Approved by: https://github.com/ezyang, https://github.com/soulitzer
2023-11-14 19:32:21 +00:00
1aece432ba Implement narrow from a regular tensor to jagged tensor (#112770)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112770
Approved by: https://github.com/cpuhrsch
2023-11-13 19:09:59 +00:00
9f3e378125 [nested tensor]add split and layer_norm_backward operations (#113108)
Summary:
Add split and layer_norm_backward.

Note: It is non trivial to support split_with_sizes backward so adding the split operation to support the use case in the model.

Test Plan: unit tests

Differential Revision: D51052966

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113108
Approved by: https://github.com/soulitzer
2023-11-08 07:44:35 +00:00
c2084da14a [NT] Backward support for broadcasting binary ops (#112519)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112519
Approved by: https://github.com/jbschlosser
ghstack dependencies: #113031
2023-11-07 00:03:21 +00:00
53fff56ab8 Graph break cleanly for test_nestedtensor (#112662)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112662
Approved by: https://github.com/jbschlosser
2023-11-03 07:20:43 +00:00
24f217ee64 [Nested tensor] Add more ops in Python subclass nested tensor (#112302)
Summary: Add dropout, split_with_sizes, and silu operations in python subclass nested tensor

Test Plan: unit tests

Differential Revision: D50676812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112302
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
2023-10-31 20:57:05 +00:00
668c3b3f3b Add embedding op to jagged NT (#112288)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112288
Approved by: https://github.com/cpuhrsch
2023-10-28 01:29:17 +00:00
73170b23d4 Add compile support for NT unbind (#111531)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111531
Approved by: https://github.com/ezyang
2023-10-23 21:16:20 +00:00
ba2ba9621c More NT subclass op support for SAM (#111253)
With this PR, we have full op support for SAM without needing to unwrap subclass into jagged buffer -> run ops -> rewrap manually. Specifically, this was previously happening in the MaskDecoder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111253
Approved by: https://github.com/soulitzer, https://github.com/cpuhrsch
2023-10-18 21:21:28 +00:00
2dc1726ab7 Compile NestedTensor with AOTAutograd (#110529)
This PR has a number of changes that improve subclass support for AOTAutograd/Inductor in general:
-  previously if a subclass does extra aliasing between graph outputs/inputs in a way, the partitioner would complain because grad_outputs are the outputs reused as-is. Now we do a view_as(self) to workaround this.
- Use dense -> dense metadata when working with fwd_output_strides during backward. This is important since the stride information comes from inductor which sees the dense to dense graph.
- Inductor requires that the inputs to the compiled backward to match some expected strides computed during compilation. We make sure to make the inner tensors of the subclass contiguous (previously, we only made the subclass itself contiguous)

Changes specific to NestedTensor relevant to compilation:
- Properly handle the case where `__tensor_unflatten__` is passed non-symbolic dense tensors and with meta extracted from fake subclasses.
- Skip var_to_range logic for singleton int
- Skip size hint logic in inductor for singleton int

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110529
Approved by: https://github.com/bdhirsh
2023-10-17 21:17:10 +00:00
4c01686027 Public API for constructing NT with jagged layout from tensor list (#111078)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111078
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #109123
2023-10-13 03:27:41 +00:00