Commit Graph

159 Commits

Author SHA1 Message Date
0ce945790e [NJT] Fix schema validation error in jagged functions (#165307)
Fixes #161812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307
Approved by: https://github.com/soulitzer
2025-10-13 17:59:18 +00:00
2035f6b2e6 use check_size instead of check_is_size in ops.py (#164668)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164668
Approved by: https://github.com/angelayi
ghstack dependencies: #164664, #164665, #164667
2025-10-08 14:23:38 +00:00
5f18f240de Add initial suppressions for pyrefly (#164177)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
`python3 scripts/lintrunner.py`
`pyrefly check`

---

Pyrefly check before: https://gist.github.com/maggiemoss/3a0aa0b6cdda0e449cd5743d5fce2c60
After:

```
 INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
 INFO 0 errors (1,063 ignored)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164177
Approved by: https://github.com/Lucaskabela
2025-10-02 20:57:41 +00:00
6a31f42da4 Fix NestedTensor max/min operations for integer dtypes. (#162273)
Fixes: https://github.com/pytorch/pytorch/issues/162049

### Summary

max_dim and min_dim functions incorrectly used torch.finfo()
for all dtypes, causing TypeError for integer tensors.

### Changes

- Use torch.iinfo() for integer dtypes instead of torch.finfo().
- Add CPU test: `test_jagged_max_min_dtypes` covering `int8, int16, int32, int64, uint8, float16, bfloat16, float32 and float64`

### Testing

Before Fix:

`python -m pytest test/test_nestedtensor.py -k "test_jagged_max_min_dtypes" -v`

Output:

```
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_bfloat16 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float16 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float32 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float64 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int16 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0005s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int32 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0005s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int64 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0004s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int8 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0004s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_uint8 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
```

After Fix:

`python -m pytest test/test_nestedtensor.py -k "test_jagged_max_min_dtypes" -v`

Output:

```
Running 9 items in this shard

test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_bfloat16 PASSED [0.0086s]                                                                                                                   [ 11%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float16 PASSED [0.0011s]                                                                                                                    [ 22%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float32 PASSED [0.0011s]                                                                                                                    [ 33%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float64 PASSED [0.0011s]                                                                                                                    [ 44%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int16 PASSED [0.0009s]                                                                                                                      [ 55%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int32 PASSED [0.0010s]                                                                                                                      [ 66%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int64 PASSED [0.0010s]                                                                                                                      [ 77%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int8 PASSED [0.0010s]                                                                                                                       [ 88%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_uint8 PASSED [0.0011s]                                                                                                                       [100%]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162273
Approved by: https://github.com/Skylion007, https://github.com/jbschlosser
2025-10-02 18:46:27 +00:00
fd785b1762 Add NestedTensor dispatch for _is_any_true/_is_all_true (#162096)
Fixes: https://github.com/pytorch/pytorch/issues/161818

### Summary
Add NestedTensor support for `_is_any_true` and `_is_all_true`.

### Changes
- Register dispatch for `aten._is_any_true.default` and
  `aten._is_all_true.default`
- Add CPU tests:
  - `test_is_any_true_jagged`: dispatch_matches_values_buffer,
    all_false_returns_false, one_true_returns_true
  - `test_is_all_true_jagged`: dispatch_matches_values_buffer,
    all_true_returns_true, any_false_returns_false

### Testing

Before Fix:

`pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v`

Output:
```
FAILED [0.0129s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu - NotImplementedError: aten._is_all_true.default
FAILED [0.0007s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu - NotImplementedError: aten._is_any_true.default
```

After Fix:

`pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v`

Output:

```
Running 2 items in this shard

test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu PASSED [0.0277s]                                                                                                                               [ 50%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu PASSED [0.0013s]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162096
Approved by: https://github.com/jbschlosser
2025-09-22 20:22:44 +00:00
bf28990c3d Add support for NestedTensor share_memory_ (#162272)
Fixes: https://github.com/pytorch/pytorch/issues/161915

### Summary

Implements share_memory_() support for NestedTensor!

### Changes

- Added share_memory_() method to NestedTensor class.
  - Shares storage for all NestedTensor components: _values, _offsets, _lengths, and cached seqlen tensors.
  - Guard for CUDA Tensors.

### Testing

Before Fix:

`pytest -q test/test_nestedtensor.py -k "test_share_memory" -v`

Output:

```
Running 1 items in this shard

test/test_nestedtensor.py Fatal Python error: Segmentation fault
```

After Fix:

`pytest -q test/test_nestedtensor.py -k "test_share_memory" -v`

Output:

```
Running 1 items in this shard

test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_share_memory_cpu PASSED [0.0753s]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162272
Approved by: https://github.com/jbschlosser
2025-09-22 19:59:58 +00:00
d08cabe314 [BC Breaking] Remove flex + njt code paths (#161734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161734
Approved by: https://github.com/jbschlosser
2025-09-16 00:13:56 +00:00
189a054cfb Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. [attempt2] (#160869)
[relanding again after fixing internal build]
Summary:
This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling  is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this https://github.com/pytorch/pytorch/pull/157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

Approved by: https://github.com/ezyang

Test Plan:
contbuild & OSS CI, see e444cd24d4

Rollback Plan:

Differential Revision: D80435179

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160869
Approved by: https://github.com/ezyang
2025-09-08 22:59:13 +00:00
ac9ccd0dc2 Add return-max-scores to flex-attention (#161667)
# Summary

### Update

API

```Py
class AuxRequest(NamedTuple):
    """Request which auxiliary outputs to compute from flex_attention.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False
    max_scores: bool = False

class AuxOutput(NamedTuple):
    """Auxiliary outputs from flex_attention operation.

    Fields will be None if not requested, or contain the tensor if requested.
    """

    lse: Optional[Tensor] = None
    max_scores: Optional[Tensor] = None

  out_only = flex_attention(query, key, value, score_mod)
  out_max, aux_max = flex_attention(
      query,
      key,
      value,
      score_mod,
      return_aux=FlexAttentionAuxRequest(max_scores=True),
  )
  out_both, aux_both = flex_attention(
      query,
      key,
      value,
      score_mod,
      return_aux=FlexAttentionAuxRequest(lse=True, max_scores=True),
        )
```

Returns the max post mod scores from flex attention.

Not being able to break BC is kinda of annoying here since we end up with a combinatorial problem where if we need to add any more return vals we need to new kwargs that gate if they get returned by the function and need to support the 2**N additional args possible return groups.

Ideally there isn't much more we need to return, but we might want to think about how best to set this up for expansion in the future. I added kwarg only now

Maybe we make a `ExtraReturns` type kwarg that can grow and we don't need to keep adding new top level args.

We could also return a Struct that holds all the extra tensors and start deprecation cycle for logsumexp eventually returning just 1 `ExtraReturns` like struct with the tensors.

### Req Grad
I currently dont return a max_scores that supports backproping grads. I think this might be feasible  but since max is essentially 1 hot 	on the inputs and a reduction we would either need to save another `max_location` from the forward or find the max_score but also only apply to first occurence if there is multiple equivalent scores (need to check if thats we define for vanilla max op in torch).

For now no grad, we can re-visit if needed.

## Perf
I am going to disable for flex_decode. Since at least initially the motivation is for training. I also more hard than it should be to have ops return nuns or optional tensors, If return max is at the false, we should probably just create a tensor of size zero so that we don't slow down the hot path.

```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,     ┆ 249.514658    ┆ 243.078974   ┆ 6.435684  ┆ 2.647569  │
│                ┆                ┆ 2048, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 57.971274     ┆ 56.633641    ┆ 1.337633  ┆ 2.361905  │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 280.71254     ┆ 275.686991   ┆ 5.025549  ┆ 1.822918  │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,    ┆ 152.970031    ┆ 150.489109   ┆ 2.480923  ┆ 1.648573  │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘

🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D)  ┆ TFlops (base) ┆ TFlops (max) ┆ delta    ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                    ┆ ---           ┆ ---          ┆ ---      ┆ ---       │
│ str            ┆ str            ┆ str                    ┆ f64           ┆ f64          ┆ f64      ┆ f64       │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,      ┆ 249.514658    ┆ 243.078974   ┆ 6.435684 ┆ 2.647569  │
│                ┆                ┆ 2048, 64)              ┆               ┆              ┆          ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 57.971274     ┆ 56.633641    ┆ 1.337633 ┆ 2.361905  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 280.71254     ┆ 275.686991   ┆ 5.025549 ┆ 1.822918  │
│                ┆                ┆ 1024, 128)             ┆               ┆              ┆          ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,     ┆ 152.970031    ┆ 150.489109   ┆ 2.480923 ┆ 1.648573  │
│                ┆                ┆ 16384, 64)             ┆               ┆              ┆          ┆           │
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,      ┆ 161.031318    ┆ 158.597808   ┆ 2.43351  ┆ 1.534391  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘

🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 4,      ┆ 175.546923    ┆ 177.81205    ┆ -2.265127 ┆ -1.273888 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4,     ┆ 156.282597    ┆ 158.209134   ┆ -1.926537 ┆ -1.217715 │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16,     ┆ 232.542929    ┆ 235.140136   ┆ -2.597207 ┆ -1.104536 │
│                ┆                ┆ 2048, 128)            ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 169.652791    ┆ 171.475986   ┆ -1.823195 ┆ -1.063236 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng
2025-09-08 22:44:48 +00:00
b82aa3df20 Revert "Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. (#159197)"
This reverts commit e444cd24d48b3a46f067974f2cc157f5ed27709f.

Reverted https://github.com/pytorch/pytorch/pull/159197 on behalf of https://github.com/laithsakka due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/159197#issuecomment-3195436668))
2025-08-18 07:22:13 +00:00
e444cd24d4 Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. (#159197)
This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling  is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this https://github.com/pytorch/pytorch/pull/157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159197
Approved by: https://github.com/ezyang
2025-08-16 09:15:58 +00:00
50580b5053 Add minimal nn.functional.log_softmax support for NestedTensor (#159662)
This only works for the jagged layout and for the non-batch and non-jagged dimensions.

I did this mostly by copy-pasting from the existing softmax implementation, but it seems fairly straightforward and I think it should work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159662
Approved by: https://github.com/jbschlosser
2025-08-06 20:34:02 +00:00
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
e01a5e9e1e Small improvements to NJT matrix multiplies (#146405)
Fixes #146404

Adds changes to the matmul and matmul_backward operation for nested jagged tensors, to support back propagation when the output is a regular strided tensor.
This required adding support for the nested matmul operation to work when the nested tensor wasn't 'self', i.e
`A@B` where `A` isn't nested but `B` is.

The operation schemas had to be updated to reflect that either input can be a strided tensor instead (and the gradient), so an extra assertion is added in an edge case where neither input is nested.

Unit tests are also added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146405
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
2025-02-06 04:51:12 +00:00
1ba1b7b597 Support remaining *_like factory functions for NJT (#144889)
Fixes #144761

This PR adds NJT impls for those *_like functions that were previously missing:
* `full_like()`
* `rand_like()`
* `randint_like()`

It also fixes a bug in existing *_like functions when a new device is specified. Fix is to also transfer `offsets` / `lengths` to the new device.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144889
Approved by: https://github.com/soulitzer
2025-01-27 21:33:51 +00:00
5725462cd8 Update NJT linear_backward to return non-aliased tensor bias grad (#145399)
Fixes https://github.com/pytorch/pytorch/issues/141292

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145399
Approved by: https://github.com/jbschlosser
ghstack dependencies: #145520, #145531, #145533
2025-01-25 00:58:04 +00:00
128f3627b1 Implement backward for NJT matmul (#144587)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

This PR implements missing backward support for NJT matmul. Notably, for dense tensors, matmul dispatches to bmm. However, due to historical reasons related to NST, NJT handles matmul directly, and thus can't rely on the CompositeImplicit impl of matmul to get the derivative formula.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144587
Approved by: https://github.com/soulitzer
ghstack dependencies: #144586
2025-01-21 18:27:50 +00:00
af204135d8 Fix NJT fill.Scalar for contiguous inputs (#144586)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

This PR implements the missing `fill.Scalar` support, which works fine for contiguous inputs, but there is still some AOTAutograd debugging required to handle non-contiguous transposed NJTs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144586
Approved by: https://github.com/soulitzer
2025-01-21 18:22:08 +00:00
805c4b597a PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145202
Approved by: https://github.com/bobrenjc93
2025-01-20 22:37:26 +00:00
b63b81410c Fix NJT frexp() to handle both outputs (#144585)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Before this PR, `frexp()` for NJT was handled via the unary pointwise fallback. The op returns a tuple, however, and the fallback doesn't handle that. This PR defines an explicit impl for `frexp()` that wraps both returned `(mantissa, exponent)` as NJTs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144585
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583, #144584
2025-01-18 15:59:56 +00:00
3ee531f8b9 Support NJT chunk() backward on batch dim (#144584)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Implements `chunk()` backward on the batch dim, which was left out before. This PR unbinds the components and invokes `copy_()` on these to pass along the appropriate gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144584
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583
2025-01-18 15:58:24 +00:00
a8ef423fed Fix NJT min / max backward() for non-ragged reductions (#144583)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

`value_selecting_reduction_backward()` is used in the backward for min / max, so this PR implements it for NJT. Notably, this isn't enough for reducing over the ragged dim, since that results in a dense tensor and thus NJT's torch_dispatch will not be called for this op. We need factory function support for nested ints to fix that case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144583
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582
2025-01-17 20:57:11 +00:00
f1cbf4b1b5 Enable ruff's unused variable checking everywhere in pytorch (#136965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136965
Approved by: https://github.com/cyyever, https://github.com/albanD
2024-12-22 02:33:11 +00:00
3f99682fbd NJT linear_backward should not return inner tensor as-is (#143333)
Fixes debug=1 use-count checks https://github.com/pytorch/pytorch/actions/runs/12187808902/job/34002323481#step:22:2521

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143333
Approved by: https://github.com/jbschlosser
2024-12-18 00:15:18 +00:00
661d1f0372 [aotd] non-contiguous NestedTensor mutation in compile (#139630)
Allow mutations mutations for subclasses that are non-contiguous.

Changes:

Removing assert in collect_metadata_analysis

Main requested testcase:
Compilation of NJT.index_put()

Adding test in test_nestedtensor.py, that compiles NJT.index_put()

It is  decomposed to NJT split,unbind, which  needed additional `torch._check`, `torch._check_is_size` for NJT.unbind()  and guard_size_oblivious() usage in _meta_registrations and _inductor/lowering.py.

Special case:
If tangent is mutated outside of the graph, it does not participate in backward graph. Autograd in this case will set this tangent to zeros tensor.

We handle it separately in CompiledFunction.backward: not doing any processing for this tangent and broadcast to number of expected subclass unwrapped arguments.

disabling for dynamo 2 tests:
1/ For nested tensor - symbolic shapes issue on nested_tensor index operation that does splits [0, 0, 0] - there is a failure with "pending unbacked symints". This PR does not add more .tolist()/item() ops than it was before.

2/ As we do not fail with exception in collect_metadata_analysis new paths for dynamo started working and it started failing with smth strange that set_ in storage_offset (because of test for views) handling updates storage "cpu" -> "meta"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139630
Approved by: https://github.com/bdhirsh
2024-12-06 12:18:46 +00:00
e803a3d83a Fix reductions for NJTs with ragged_idx != 1 (#142173)
**Background:** conversion from outer dim -> inner dim makes the (previously valid) assumption that the ragged dim is immediately next to the batch dim. This is no longer the case after #137125.

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual ragged_idx. Since ragged_idx tells us where the packed ragged / batch dim is, both ragged and batch outer dims should map to this inner dim. The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass this.
* Fixes outputs across keepdim settings when reducing over ragged / batch dims.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
2024-12-06 01:23:17 +00:00
c9e2b3fefe NJT: Return correct number of outputs for chunk() on the batch dim (#141604)
Old logic was completely wrong, returning `chunk_size` chunks instead of the intended number. The original test didn't catch this because `chunk_size == num_chunks` :p New OpInfo-based testing covers it though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141604
Approved by: https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392, #141506
2024-11-27 02:31:23 +00:00
43121b6f0d Adjust output NJT ragged_idx for reductions and select() (#141506)
This fixes some bugs when performing reductions / select() on dims before the ragged dim. In this case, the output NJT has a smaller number of dims, and its ragged_idx should reflect that correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141506
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392
2024-11-27 02:25:53 +00:00
23793cf93d NJT unsqueeze() fixes (#141392)
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736, #140161
2024-11-26 22:38:35 +00:00
9ee5d6f83c Initial NJT testing over dim type / views (#140161)
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736
2024-11-26 22:08:08 +00:00
869d629c0f Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500
2024-11-26 21:19:58 +00:00
8ba555ec8a Fix where() for NJT (#141500)
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](78491d6afc/tools/autograd/derivatives.yaml (L432-L434))). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of #140736 due to softshrink's derivative formula).

**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
    * Uses limited `broadcast_tensors()` / `broadcast_to()` support
    * Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to uses `scalar_tensor()` instead of e.g. `zeros({})`

**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This is an extensive change and includes `torch.where()` logic, a bunch of derivative formulas, and likely other places not yet discovered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-26 20:13:27 +00:00
cffeb83f15 Revert "Forward / backward NJT support for several activation functions (#140736)"
This reverts commit daaecb96d6b8049f8ca95974cd8a45b2fb9d4e28.

Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/malfet due to Take 2, of stack revert your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498479702))
2024-11-25 16:27:00 +00:00
e0f9ec4a25 Revert "Initial NJT testing over dim type / views (#140161)"
This reverts commit 730caf0aed187ce5c1c36fae7e9ae1f700585280.

Reverted https://github.com/pytorch/pytorch/pull/140161 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498358652))
2024-11-25 15:40:54 +00:00
58727b6f5f Revert "NJT unsqueeze() fixes (#141392)"
This reverts commit 48409a5cc6b14b6a5237beb6263a436d309afcd2.

Reverted https://github.com/pytorch/pytorch/pull/141392 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498358652))
2024-11-25 15:40:54 +00:00
48409a5cc6 NJT unsqueeze() fixes (#141392)
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #140736, #140161
2024-11-25 08:08:38 +00:00
730caf0aed Initial NJT testing over dim type / views (#140161)
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #140736
2024-11-25 08:08:38 +00:00
daaecb96d6 Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
2024-11-25 08:08:31 +00:00
0be0c944b1 Revert "Forward / backward NJT support for several activation functions (#140736)"
This reverts commit af70f5e04c69839a1a0e08942254c170dc4c3d61.

Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/huydhn due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2495075871))
2024-11-22 23:15:55 +00:00
af70f5e04c Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
2024-11-22 22:05:53 +00:00
41f315417c Fix NJT linear_backward() memory usage (#141163)
Fixes #141112

The formula we're using for `linear_backward()` is inefficient for higher dim input sizes, even if the input is trivially higher dim (e.g. via use of `unsqueeze()`). This PR updates the formula to match the more efficient version employed by NST. Specifically, note the leading dim collapse for `grad_output`'s values before we compute the various matmuls.
d5ee1d1b58/aten/src/ATen/native/nested/NestedTensorBackward.cpp (L37-L70)

Testing for correctness is done via existing gradcheck tests (e.g. `test_backward_nn_functional_linear`). I added a memory usage test but I think it's likely there's a better way to do this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141163
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-21 15:22:45 +00:00
c1f21bf2b6 Made FlexAttention error on subgraph lowering failure (#140331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140331
Approved by: https://github.com/drisspg
2024-11-17 02:43:58 +00:00
9c678af9f9 Misc. non-contig NJT fixes (#140160)
This PR contains several fixes related to non-contiguous NJTs:
1. Propagates `lengths` through op calls appropriately (see desc of #138098)
    * SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()`
2. Allows non-contig NJTs in unsqueeze / transpose / select
3. Expands padded dense -> NJT conversion to support non-contig NJTs
4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140160
Approved by: https://github.com/cpuhrsch
2024-11-09 01:18:26 +00:00
ddb291a881 Fix and test several NJT reductions (#139317)
I'm sick of reductions not working properly - spotty dim coverage, missing backwards, etc. This PR fixes quite a bit.

It applies to the following ops:
* `sum` / `mean` / `prod`
* `all` / `any`
* `amin` / `amax`
* `min` / `max`
* `argmin` / `argmax`

The general reduction logic has been factored out into a helper `_apply_reduction(func, func_name, identity_element, *args, **kwargs)`. The idea is that by providing a valid identity element, we can utilize conversions to padded dense when needed for reducing over the ragged dim.

Extensive test coverage includes:
* reductions across ragged dim
* reductions across non-batch, non-ragged dims
* reductions across both batch and ragged dims
* multiple dim reductions (for ops that support this)
* full reduction -> scalar

Bonus: the PR includes backwards fixes for `sum` and `mean`, which have never worked.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139317
Approved by: https://github.com/cpuhrsch
2024-10-31 20:55:38 +00:00
ad637a4c5c Add support for index_put_ in NT (#135722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135722
Approved by: https://github.com/jbschlosser
2024-10-30 17:17:59 +00:00
5861279f47 Revert "Add support for index_put_ in NT (#135722)"
This reverts commit b4836e5b5ce2891e9af21790d255720e2dbf8e91.

Reverted https://github.com/pytorch/pytorch/pull/135722 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing on ROCm ([comment](https://github.com/pytorch/pytorch/pull/135722#issuecomment-2445651914))
2024-10-30 01:53:55 +00:00
b4836e5b5c Add support for index_put_ in NT (#135722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135722
Approved by: https://github.com/jbschlosser
2024-10-30 00:03:21 +00:00
2b577ae58f Implement NJT embedding backward (#138627)
Fixes #138352

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138627
Approved by: https://github.com/jbschlosser
2024-10-29 18:44:58 +00:00
8ba9063002 FlexAttention support for NJT (#136792)
This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR handles conversions for indices within the giant "stacked sequence" -> sequence relative indices automatically.
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
    * `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
      * Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.

Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)  # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)

def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```

TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)` length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
    * Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
2024-10-28 20:01:27 +00:00
f089d5ffef Improve input validation for NJT pointwise ops (#138602)
Before this PR, NJT would dispatch e.g. `NJT * nested_int` to `mul.Tensor`, wrongly interpreting the SymInt as a tensor and outputting garbage. This PR verifies that there are no nested ints in the list of args before dispatching for pointwise ops.

I originally tried checking that `the number of passed tensor args == the number of func schema tensor args`, but this wrongly disallows `nt * 2`, which (non-intuitively to me at least at first) dispatches via the `mul.Tensor` overload.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138602
Approved by: https://github.com/soulitzer
2024-10-22 20:13:12 +00:00