249 Commits

Author SHA1 Message Date
8ba555ec8a Fix where() for NJT (#141500)
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](78491d6afc/tools/autograd/derivatives.yaml (L432-L434))). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of #140736 due to softshrink's derivative formula).

**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
    * Uses limited `broadcast_tensors()` / `broadcast_to()` support
    * Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to uses `scalar_tensor()` instead of e.g. `zeros({})`

**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This is an extensive change and includes `torch.where()` logic, a bunch of derivative formulas, and likely other places not yet discovered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-26 20:13:27 +00:00
cffeb83f15 Revert "Forward / backward NJT support for several activation functions (#140736)"
This reverts commit daaecb96d6b8049f8ca95974cd8a45b2fb9d4e28.

Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/malfet due to Take 2, of stack revert your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498479702))
2024-11-25 16:27:00 +00:00
e0f9ec4a25 Revert "Initial NJT testing over dim type / views (#140161)"
This reverts commit 730caf0aed187ce5c1c36fae7e9ae1f700585280.

Reverted https://github.com/pytorch/pytorch/pull/140161 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498358652))
2024-11-25 15:40:54 +00:00
58727b6f5f Revert "NJT unsqueeze() fixes (#141392)"
This reverts commit 48409a5cc6b14b6a5237beb6263a436d309afcd2.

Reverted https://github.com/pytorch/pytorch/pull/141392 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498358652))
2024-11-25 15:40:54 +00:00
48409a5cc6 NJT unsqueeze() fixes (#141392)
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #140736, #140161
2024-11-25 08:08:38 +00:00
730caf0aed Initial NJT testing over dim type / views (#140161)
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #140736
2024-11-25 08:08:38 +00:00
daaecb96d6 Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
2024-11-25 08:08:31 +00:00
2e7ba0b194 Revert "Switch to using Python nested int (#141166)"
This reverts commit e2e8a7fa2e519433a4ec1071f80d2f6f843c6300.

Reverted https://github.com/pytorch/pytorch/pull/141166 on behalf of https://github.com/clee2000 due to broke docs [GH job link](https://github.com/pytorch/pytorch/actions/runs/11980936976/job/33406870951) [HUD commit link](e2e8a7fa2e) ([comment](https://github.com/pytorch/pytorch/pull/141166#issuecomment-2495112297))
2024-11-22 23:54:36 +00:00
0be0c944b1 Revert "Forward / backward NJT support for several activation functions (#140736)"
This reverts commit af70f5e04c69839a1a0e08942254c170dc4c3d61.

Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/huydhn due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2495075871))
2024-11-22 23:15:55 +00:00
e2e8a7fa2e Switch to using Python nested int (#141166)
Doesn't seem to noticeably slow down eager - TestNestedTensorSubclass tests with and without the PR finished in similar amounts of time (around 57s, 58s)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141166
Approved by: https://github.com/ezyang
2024-11-22 22:12:25 +00:00
af70f5e04c Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
2024-11-22 22:05:53 +00:00
41f315417c Fix NJT linear_backward() memory usage (#141163)
Fixes #141112

The formula we're using for `linear_backward()` is inefficient for higher dim input sizes, even if the input is trivially higher dim (e.g. via use of `unsqueeze()`). This PR updates the formula to match the more efficient version employed by NST. Specifically, note the leading dim collapse for `grad_output`'s values before we compute the various matmuls.
d5ee1d1b58/aten/src/ATen/native/nested/NestedTensorBackward.cpp (L37-L70)

Testing for correctness is done via existing gradcheck tests (e.g. `test_backward_nn_functional_linear`). I added a memory usage test but I think it's likely there's a better way to do this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141163
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-21 15:22:45 +00:00
b63a84804c Allow NJT by default for weights_only torch.load (take 2) (#140739)
Per discussion with @malfet, only allow weights_only unpickler to load NJT if `torch.nested` and `torch._dynamo`  are imported

(this is slightly weird as technically `torch.nested` is actually imported by default and `torch._dynamo.decorators._DimRange` is actually what needs to be imported)

we can't import this from `torch.nested` as this would
- undo dynamo lazy import
- cause circular import

===========================
Redo of https://github.com/pytorch/pytorch/pull/140304 caused issues as `torch.nested._internal.foo` needs to be imported, which causes issues like

```python
torch/_weights_only_unpickler.py", line 339, in load
    if full_path in _get_allowed_globals():
torch/_weights_only_unpickler.py", line 188, in _get_allowed_globals
    torch.nested._internal.nested_tensor.NestedTensor
AttributeError: module 'torch.nested' has no attribute '_internal'
```

**This likely wasn't caught in our CI because imports are global during unit tests(?), so we use subprocess to properly test this time**

Differential Revision: [D65961691](https://our.internmc.facebook.com/intern/diff/D65961691)

@jbschlosser
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140739
Approved by: https://github.com/malfet
2024-11-19 02:44:53 +00:00
c1f21bf2b6 Made FlexAttention error on subgraph lowering failure (#140331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140331
Approved by: https://github.com/drisspg
2024-11-17 02:43:58 +00:00
9c678af9f9 Misc. non-contig NJT fixes (#140160)
This PR contains several fixes related to non-contiguous NJTs:
1. Propagates `lengths` through op calls appropriately (see desc of #138098)
    * SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()`
2. Allows non-contig NJTs in unsqueeze / transpose / select
3. Expands padded dense -> NJT conversion to support non-contig NJTs
4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140160
Approved by: https://github.com/cpuhrsch
2024-11-09 01:18:26 +00:00
3abbde976d Allow any single non-batch dim to be ragged for NJT (#137125)
Fixes #137512

Relaxes the restriction that the ragged dim is immediately next to the batch dim e.g. `(B, *, D_0, ..., D_N)`. This allows for constructing NJTs of shape e.g. `(B, D, j0)` directly. It's possible before this PR to get an NJT of e.g. shape `(B, D, j0)` by constructing an NJT of shape `(B, j0, D)` and transposing it. This PR allows a user to go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this.

At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137125
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-06 18:50:08 +00:00
ddb291a881 Fix and test several NJT reductions (#139317)
I'm sick of reductions not working properly - spotty dim coverage, missing backwards, etc. This PR fixes quite a bit.

It applies to the following ops:
* `sum` / `mean` / `prod`
* `all` / `any`
* `amin` / `amax`
* `min` / `max`
* `argmin` / `argmax`

The general reduction logic has been factored out into a helper `_apply_reduction(func, func_name, identity_element, *args, **kwargs)`. The idea is that by providing a valid identity element, we can utilize conversions to padded dense when needed for reducing over the ragged dim.

Extensive test coverage includes:
* reductions across ragged dim
* reductions across non-batch, non-ragged dims
* reductions across both batch and ragged dims
* multiple dim reductions (for ops that support this)
* full reduction -> scalar

Bonus: the PR includes backwards fixes for `sum` and `mean`, which have never worked.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139317
Approved by: https://github.com/cpuhrsch
2024-10-31 20:55:38 +00:00
ad637a4c5c Add support for index_put_ in NT (#135722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135722
Approved by: https://github.com/jbschlosser
2024-10-30 17:17:59 +00:00
5861279f47 Revert "Add support for index_put_ in NT (#135722)"
This reverts commit b4836e5b5ce2891e9af21790d255720e2dbf8e91.

Reverted https://github.com/pytorch/pytorch/pull/135722 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing on ROCm ([comment](https://github.com/pytorch/pytorch/pull/135722#issuecomment-2445651914))
2024-10-30 01:53:55 +00:00
b4836e5b5c Add support for index_put_ in NT (#135722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135722
Approved by: https://github.com/jbschlosser
2024-10-30 00:03:21 +00:00
2b577ae58f Implement NJT embedding backward (#138627)
Fixes #138352

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138627
Approved by: https://github.com/jbschlosser
2024-10-29 18:44:58 +00:00
8785353f2f Fix tensor subclass + dynamic shapes in torch.compile + aot autograd (#125941)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125941
Approved by: https://github.com/bdhirsh
ghstack dependencies: #133337
2024-10-28 21:58:59 +00:00
8ba9063002 FlexAttention support for NJT (#136792)
This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR handles conversions for indices within the giant "stacked sequence" -> sequence relative indices automatically.
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
    * `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
      * Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.

Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)  # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)

def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```

TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)` length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
    * Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
2024-10-28 20:01:27 +00:00
f089d5ffef Improve input validation for NJT pointwise ops (#138602)
Before this PR, NJT would dispatch e.g. `NJT * nested_int` to `mul.Tensor`, wrongly interpreting the SymInt as a tensor and outputting garbage. This PR verifies that there are no nested ints in the list of args before dispatching for pointwise ops.

I originally tried checking that `the number of passed tensor args == the number of func schema tensor args`, but this wrongly disallows `nt * 2`, which (non-intuitively to me at least at first) dispatches via the `mul.Tensor` overload.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138602
Approved by: https://github.com/soulitzer
2024-10-22 20:13:12 +00:00
134f6cda7e Support record_stream() for NJT (#137099)
Does what it says on the tin. I believe the right behavior here is to ensure that `record_stream()` is called on all tensor components of the NJT to ensure they all live until stream computation is complete.

This is an ask from torchrec as the op is used there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137099
Approved by: https://github.com/ngimel
2024-10-21 21:10:42 +00:00
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
ecc5e05854 Refactor NJT min / max seqlen handling for convenience (#138130)
There's an annoying pattern emerging for pulling out the NJT min / max seqlen ints if they exist without computing / caching if they don't. This PR introduces private convenience functions to simplify handling this and avoiding redundant checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138130
Approved by: https://github.com/soulitzer
2024-10-17 17:28:39 +00:00
906fe05895 Naive impls for NJT matmul (#138121)
Our matmul support is abysmal - let's at least get this working and do it performantly later.

Bonus: implements `bmm` as well.

jagged <-> padded dense conversions are utilized when possible, and an unbind-based fallback otherwise (the former works with torch.compile and the latter doesn't). Some testing is missing because we don't have factory function support yet :(
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138121
Approved by: https://github.com/cpuhrsch
2024-10-17 01:31:46 +00:00
3e2f276a14 Fix to() on non-contiguous NJTs (#137124)
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): #137275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137124
Approved by: https://github.com/soulitzer
ghstack dependencies: #137030, #137031
2024-10-08 15:11:05 +00:00
e95b230fd8 Fix NJT serialization (#137031)
Fixes #129366

Since NJT has custom serialization logic, we need an NJT-specific fix to clear out cached sizes / strides PyCapsules. Eventually, we should switch NJT to use the default serialization logic, but this depends on #125622 being addressed.

This PR also makes serialization more complete by explicitly handling `lengths`, `ragged_idx`, and the `metadata_cache`, ensuring working operation for both contiguous and non-contiguous NJTs,
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137031
Approved by: https://github.com/soulitzer
ghstack dependencies: #137030
2024-10-02 21:41:35 +00:00
991f8f8ec3 Bias gradient calculation for NJT linear backward (#136660)
Previously NYI - @mikaylagawarecki needs it for Transformers.

Fixes #136652
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136660
Approved by: https://github.com/soulitzer
2024-09-26 21:38:10 +00:00
4150ab44a4 Fix composite op redispatch for NJT in inference mode (#134683)
Prior to this PR, calling `reshape()` under `inference_mode()` would throw a `NotImplementedError`. This is because `inference_mode()` disables autograd key dispatch, incidentally preventing the decomposition of reshape for NJT.

This PR fixes this by redispatching on the `CompositeImplicitAutogradNestedTensor` key whenever a composite implicit op is encountered in `NJT.__torch_dispatch__()`. This fixes reshape and any other composite implicit ops underneath `inference_mode()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134683
Approved by: https://github.com/soulitzer, https://github.com/albanD
ghstack dependencies: #136566
2024-09-26 14:10:53 +00:00
888744bd36 NJT binary pointwise broadcasting support via jagged <-> padded dense conversion (#133021)
Related: #132695

This PR uses padded dense <-> jagged conversions to handle binary pointwise broadcasting of (NT, T) and (T, NT). This includes:
* `(B, j0, D) + (1, 1, 1)`
* `(B, j0, D) + (B, 1, 1)`
* `(B, j0, D) + (B, 1, D)`
* etc.

This PR also adds (hacky) support for bool inputs to the jagged <-> padded dense conversions. The underlying CUDA kernels do not support integer / bool inputs; so the following workaround is employed: `convert input -> half, run conversion kernel, convert output -> bool`. Note that this bool support is needed specifically for the backward formula of `fmax`, and likely others.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133021
Approved by: https://github.com/cpuhrsch
2024-09-24 19:11:49 +00:00
a8382847f4 Support rms_norm() for NJT (#135872)
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947
2024-09-17 18:09:20 +00:00
31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00
3117f2cf67 Revert "[BE]: Update mypy to 1.11.2 (#133816)"
This reverts commit 55299cfc223fa838aadd8d6d6fa3ed541fa5acd1.

Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684))
2024-09-16 09:11:16 +00:00
55299cfc22 [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-14 21:40:36 +00:00
06bc717410 Fix sum() forward for NJT (#131945)
This PR solves two problems with `sum()` support in NJT:
* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519.
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
2024-09-14 00:58:03 +00:00
525bec804c NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-12 17:54:25 +00:00
70a65a8bd5 Revert "NJT <-> padded dense conversions (#125947)"
This reverts commit 09a5e88bef04d5485b70d8f65f46a675aaa52942.

Reverted https://github.com/pytorch/pytorch/pull/125947 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing dynamo test 09a5e88bef, maybe a landrace ([comment](https://github.com/pytorch/pytorch/pull/125947#issuecomment-2339228570))
2024-09-09 22:01:09 +00:00
09a5e88bef NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-09 19:37:32 +00:00
defb515306 [NJT]Add permute ops support (#135336)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336
Approved by: https://github.com/davidberard98
2024-09-08 21:00:41 +00:00
4b4ba7ab06 [NJT] Support NJT SDPA + meta-device flop counting (#134289)
A user wants to use the flop counter with meta devices. This previously caused problems for SDPA+NJT:

1. autocast check: `torch.is_autocast_enabled("meta")` fails because `meta` is not valid for autocasting. If we skip this, we run into the next error
2. math backend: conversion to NST requires getting concrete offsets in a list of python integers, which doesn't work on a meta tensor b2eb0e8c6a/torch/nested/_internal/sdpa.py (L809-L815)
3. (fixed in the previous PR, #134288) - if we force using flash attention backend for flop counting, `_flash_attention_forward` previously didn't support meta tensors.

In this PR, we check specifically for FlopCounterMode, and, if it's enabled and combined with meta tensors, (a) skip autocasting and (b) force it down the flash attention path. This isn't generally safe for tracing (e.g. if you actually care which kernels you are running), but in the absence of actual device information, we have to make some assumptions. By specifically checking for FlopCounterMode, this should reduce the chance of unintended side effects for other meta tensor users.

Note: fake tensor would solve a bunch of these issues, but it's not a viable solution right now for the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134289
Approved by: https://github.com/soulitzer
ghstack dependencies: #134288
2024-08-29 03:43:42 +00:00
ed86ac2f25 [BE] typing for decorators - fx/_compatibility (#134054)
Summary: See #131429

Test Plan: unit tests pass

Differential Revision: D61493706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134054
Approved by: https://github.com/oulgen
2024-08-26 04:00:27 +00:00
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
44fa9f991c [NJT] add aten.to.dtype support (#134164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134164
Approved by: https://github.com/davidberard98
2024-08-22 16:59:38 +00:00
2e1830c7c8 Implement 2D version of masked_select for nestedtensors (#133889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133889
Approved by: https://github.com/soulitzer
2024-08-20 21:46:32 +00:00
4af4910b1a Reland "Construct NJT without graph breaks" (#133196)
This reverts commit 154d40ca488e6979ce9c2de89d8a35b53129ebea.

and adds changes from https://github.com/pytorch/pytorch/pull/133061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133196
Approved by: https://github.com/ezyang
ghstack dependencies: #133145
2024-08-14 01:11:13 +00:00
656465fc77 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit ed97fb77f9a9d9d815f4975caccbc961ebbcb714.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/izaitsevfb due to fails internal jobs, see [S440348](https://www.internalfb.com/sevmanager/view/440348) ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2285051164))
2024-08-12 23:14:19 +00:00
05de2b2d0f Revert "Construct NJT without graph breaks" (#133145)
This reverts commit 911154271309667b55dfb963ec6384bd0048019b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ
2024-08-10 03:11:16 +00:00