249 Commits

Author SHA1 Message Date
fdab48a7c1 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff. Some rules from this family were already enabled; the newly added rules are:
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```
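
For illustration, a minimal (hypothetical) snippet showing what the two new rules flag:

```python
import enum


class Color(enum.Enum):
    RED = 1
    CRIMSON = 1  # PIE796: Enum contains duplicate value


for i in range(0, 10):  # PIE808: unnecessary start argument; prefer range(10)
    pass
```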

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 07:36:18 +00:00
24520b8386 Revert "Enable all PIE rules on ruff (#165814)"
This reverts commit c79dfdc6550e872783aa5cb5fc9e86589bf18872.

Reverted https://github.com/pytorch/pytorch/pull/165814 on behalf of https://github.com/cyyever due to Need to cover more files ([comment](https://github.com/pytorch/pytorch/pull/165814#issuecomment-3417931863))
2025-10-18 07:21:08 +00:00
c79dfdc655 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff. Some rules from this family were already enabled; the newly added rules are:
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 06:40:12 +00:00
0ce945790e [NJT] Fix schema validation error in jagged functions (#165307)
Fixes #161812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307
Approved by: https://github.com/soulitzer
2025-10-13 17:59:18 +00:00
fb64da0791 [2/N] Use "is" in python type comparison (#165142)
This is a follow-up of #165037. It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks.
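
For illustration, a minimal example of the pattern this series targets (names are made up):

```python
x = 3

# Discouraged: equality comparison of types
if type(x) == int:
    pass

# Preferred: identity comparison for exact type checks
if type(x) is int:
    pass

# isinstance() remains the right tool when subclasses should also match.
if isinstance(x, int):
    pass
```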

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
2025-10-10 15:36:44 +00:00
2035f6b2e6 use check_size instead of check_is_size in ops.py (#164668)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164668
Approved by: https://github.com/angelayi
ghstack dependencies: #164664, #164665, #164667
2025-10-08 14:23:38 +00:00
5f18f240de Add initial suppressions for pyrefly (#164177)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
`python3 scripts/lintrunner.py`
`pyrefly check`

---

Pyrefly check before: https://gist.github.com/maggiemoss/3a0aa0b6cdda0e449cd5743d5fce2c60
After:

```
 INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
 INFO 0 errors (1,063 ignored)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164177
Approved by: https://github.com/Lucaskabela
2025-10-02 20:57:41 +00:00
6a31f42da4 Fix NestedTensor max/min operations for integer dtypes. (#162273)
Fixes: https://github.com/pytorch/pytorch/issues/162049

### Summary

The max_dim and min_dim functions incorrectly used torch.finfo() for all dtypes, causing a TypeError for integer tensors (see the sketch after the Changes list).

### Changes

- Use torch.iinfo() for integer dtypes instead of torch.finfo().
- Add CPU test: `test_jagged_max_min_dtypes` covering `int8, int16, int32, int64, uint8, float16, bfloat16, float32 and float64`
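
A minimal sketch of the dtype dispatch described above (the helper name is hypothetical, not the actual code in the PR):

```python
import torch


def dtype_min(dtype: torch.dtype):
    # Pick the right "info" helper per dtype instead of calling finfo unconditionally.
    if dtype.is_floating_point:
        return torch.finfo(dtype).min
    return torch.iinfo(dtype).min


print(dtype_min(torch.float16))  # works as before
print(dtype_min(torch.int32))    # previously raised TypeError via torch.finfo()
```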

### Testing

Before Fix:

`python -m pytest test/test_nestedtensor.py -k "test_jagged_max_min_dtypes" -v`

Output:

```
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_bfloat16 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float16 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float32 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float64 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0006s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int16 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0005s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int32 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0005s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int64 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0004s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int8 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
FAILED [0.0004s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_uint8 - TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'
```

After Fix:

`python -m pytest test/test_nestedtensor.py -k "test_jagged_max_min_dtypes" -v`

Output:

```
Running 9 items in this shard

test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_bfloat16 PASSED [0.0086s]                                                                                                                   [ 11%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float16 PASSED [0.0011s]                                                                                                                    [ 22%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float32 PASSED [0.0011s]                                                                                                                    [ 33%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_float64 PASSED [0.0011s]                                                                                                                    [ 44%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int16 PASSED [0.0009s]                                                                                                                      [ 55%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int32 PASSED [0.0010s]                                                                                                                      [ 66%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int64 PASSED [0.0010s]                                                                                                                      [ 77%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_int8 PASSED [0.0010s]                                                                                                                       [ 88%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_jagged_max_min_dtypes_cpu_uint8 PASSED [0.0011s]                                                                                                                       [100%]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162273
Approved by: https://github.com/Skylion007, https://github.com/jbschlosser
2025-10-02 18:46:27 +00:00
fd785b1762 Add NestedTensor dispatch for _is_any_true/_is_all_true (#162096)
Fixes: https://github.com/pytorch/pytorch/issues/161818

### Summary
Add NestedTensor support for `_is_any_true` and `_is_all_true` (usage sketch after the Changes list).

### Changes
- Register dispatch for `aten._is_any_true.default` and
  `aten._is_all_true.default`
- Add CPU tests:
  - `test_is_any_true_jagged`: dispatch_matches_values_buffer,
    all_false_returns_false, one_true_returns_true
  - `test_is_all_true_jagged`: dispatch_matches_values_buffer,
    all_true_returns_true, any_false_returns_false
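
A rough usage sketch (assuming a boolean jagged NJT and calling the aten ops directly):

```python
import torch

a = torch.tensor([[True, False], [False, False]])
b = torch.tensor([[False, False]])
nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)

# After this change the ops dispatch on the NJT's values buffer.
print(torch.ops.aten._is_any_true(nt).item())  # True
print(torch.ops.aten._is_all_true(nt).item())  # False
```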

### Testing

Before Fix:

`pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v`

Output:
```
FAILED [0.0129s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu - NotImplementedError: aten._is_all_true.default
FAILED [0.0007s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu - NotImplementedError: aten._is_any_true.default
```

After Fix:

`pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v`

Output:

```
Running 2 items in this shard

test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu PASSED [0.0277s]                                                                                                                               [ 50%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu PASSED [0.0013s]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162096
Approved by: https://github.com/jbschlosser
2025-09-22 20:22:44 +00:00
bf28990c3d Add support for NestedTensor share_memory_ (#162272)
Fixes: https://github.com/pytorch/pytorch/issues/161915

### Summary

Implements share_memory_() support for NestedTensor! (A usage sketch follows the Changes list.)

### Changes

- Added share_memory_() method to NestedTensor class.
  - Shares storage for all NestedTensor components: _values, _offsets, _lengths, and cached seqlen tensors.
  - Guard for CUDA Tensors.
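
A rough usage sketch, assuming a CPU jagged NJT:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)

# Moves the underlying storages (values, offsets, ...) into shared memory,
# e.g. before handing the tensor to another process.
nt.share_memory_()
print(nt.values().is_shared(), nt.offsets().is_shared())  # True True
```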

### Testing

Before Fix:

`pytest -q test/test_nestedtensor.py -k "test_share_memory" -v`

Output:

```
Running 1 items in this shard

test/test_nestedtensor.py Fatal Python error: Segmentation fault
```

After Fix:

`pytest -q test/test_nestedtensor.py -k "test_share_memory" -v`

Output:

```
Running 1 items in this shard

test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_share_memory_cpu PASSED [0.0753s]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162272
Approved by: https://github.com/jbschlosser
2025-09-22 19:59:58 +00:00
d08cabe314 [BC Breaking] Remove flex + njt code paths (#161734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161734
Approved by: https://github.com/jbschlosser
2025-09-16 00:13:56 +00:00
189a054cfb Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. [attempt2] (#160869)
[relanding again after fixing internal build]
Summary:
This might cause some new DDEs at call sites that do not use is_contiguous_or_false() or sym_is_contiguous(),
but we want to find those call sites and handle them properly by calling is_contiguous_or_false() instead of is_contiguous() explicitly where appropriate.
I had to fix one issue after removing the implicit size-oblivious reasoning; here is the context.

In https://github.com/pytorch/pytorch/pull/157472, we defined sym_is_contiguous to be the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression that represents contiguity and is guaranteed not to throw a DDE.

When is_contiguous is called, we do sym_is_contiguous().guard_bool();
when is_contiguous_or_false is called, we do sym_is_contiguous().guard_or_false().

One path that was not handled well was this:
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
Namely, if we call sym_is_contiguous_custom and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, then we used to call is_contiguous(this, memory_format).

This used to go through load_pyobj_interpreter and end up calling the Python is_contiguous, which used implicit size-oblivious reasoning.
Once we removed that implicit size-oblivious reasoning, the right thing to call is
pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get a DDE even when the caller is using sym_is_contiguous.

So I had to define it for the PyInterpreter, and then override it for nested tensors.

Approved by: https://github.com/ezyang

Test Plan:
contbuild & OSS CI, see e444cd24d4

Rollback Plan:

Differential Revision: D80435179

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160869
Approved by: https://github.com/ezyang
2025-09-08 22:59:13 +00:00
ac9ccd0dc2 Add return-max-scores to flex-attention (#161667)
# Summary

### Update

API

```Py
class AuxRequest(NamedTuple):
    """Request which auxiliary outputs to compute from flex_attention.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False
    max_scores: bool = False

class AuxOutput(NamedTuple):
    """Auxiliary outputs from flex_attention operation.

    Fields will be None if not requested, or contain the tensor if requested.
    """

    lse: Optional[Tensor] = None
    max_scores: Optional[Tensor] = None

out_only = flex_attention(query, key, value, score_mod)
out_max, aux_max = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(max_scores=True),
)
out_both, aux_both = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(lse=True, max_scores=True),
)
```

Returns the max post mod scores from flex attention.

Not being able to break BC is kind of annoying here, since we end up with a combinatorial problem: if we need to add any more return values, we need new kwargs that gate whether they get returned by the function, and we need to support the 2**N possible return groups from the additional args.

Ideally there isn't much more we need to return, but we might want to think about how best to set this up for expansion in the future. For now I added a kwarg-only argument.

Maybe we make an `ExtraReturns`-type kwarg that can grow, so we don't need to keep adding new top-level args.

We could also return a struct that holds all the extra tensors and start a deprecation cycle for logsumexp, eventually returning just one `ExtraReturns`-like struct with the tensors.

### Req Grad
I currently don't return a max_scores that supports backpropagating grads. I think this might be feasible, but since max is essentially one-hot on the inputs plus a reduction, we would either need to save another `max_location` from the forward or find the max_score and apply the gradient only to the first occurrence when there are multiple equal scores (need to check whether that's what we define for the vanilla max op in torch).

For now no grad, we can re-visit if needed.

## Perf
I am going to disable this for flex_decode, since at least initially the motivation is for training. It is also harder than it should be to have ops return None or optional tensors; if returning max_scores is disabled, we should probably just create a tensor of size zero so that we don't slow down the hot path.

```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,     ┆ 249.514658    ┆ 243.078974   ┆ 6.435684  ┆ 2.647569  │
│                ┆                ┆ 2048, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 57.971274     ┆ 56.633641    ┆ 1.337633  ┆ 2.361905  │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 280.71254     ┆ 275.686991   ┆ 5.025549  ┆ 1.822918  │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,    ┆ 152.970031    ┆ 150.489109   ┆ 2.480923  ┆ 1.648573  │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘

🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D)  ┆ TFlops (base) ┆ TFlops (max) ┆ delta    ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                    ┆ ---           ┆ ---          ┆ ---      ┆ ---       │
│ str            ┆ str            ┆ str                    ┆ f64           ┆ f64          ┆ f64      ┆ f64       │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 2048, 16,      ┆ 249.514658    ┆ 243.078974   ┆ 6.435684 ┆ 2.647569  │
│                ┆                ┆ 2048, 64)              ┆               ┆              ┆          ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 57.971274     ┆ 56.633641    ┆ 1.337633 ┆ 2.361905  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
│ noop           ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,      ┆ 280.71254     ┆ 275.686991   ┆ 5.025549 ┆ 1.822918  │
│                ┆                ┆ 1024, 128)             ┆               ┆              ┆          ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16,     ┆ 152.970031    ┆ 150.489109   ┆ 2.480923 ┆ 1.648573  │
│                ┆                ┆ 16384, 64)             ┆               ┆              ┆          ┆           │
│ causal         ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,      ┆ 161.031318    ┆ 158.597808   ┆ 2.43351  ┆ 1.534391  │
│                ┆                ┆ 1024, 64)              ┆               ┆              ┆          ┆           │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘

🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type      ┆ dtype          ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta     ┆ pct_delta │
│ ---            ┆ ---            ┆ ---                   ┆ ---           ┆ ---          ┆ ---       ┆ ---       │
│ str            ┆ str            ┆ str                   ┆ f64           ┆ f64          ┆ f64       ┆ f64       │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop           ┆ torch.bfloat16 ┆ (4, 16, 1024, 16,     ┆ 244.052884    ┆ 248.65129    ┆ -4.598406 ┆ -1.849339 │
│                ┆                ┆ 1024, 64)             ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 4,      ┆ 175.546923    ┆ 177.81205    ┆ -2.265127 ┆ -1.273888 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4,     ┆ 156.282597    ┆ 158.209134   ┆ -1.926537 ┆ -1.217715 │
│                ┆                ┆ 16384, 64)            ┆               ┆              ┆           ┆           │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16,     ┆ 232.542929    ┆ 235.140136   ┆ -2.597207 ┆ -1.104536 │
│                ┆                ┆ 2048, 128)            ┆               ┆              ┆           ┆           │
│ alibi          ┆ torch.bfloat16 ┆ (2, 16, 1024, 16,     ┆ 169.652791    ┆ 171.475986   ┆ -1.823195 ┆ -1.063236 │
│                ┆                ┆ 1024, 128)            ┆               ┆              ┆           ┆           │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng
2025-09-08 22:44:48 +00:00
b82aa3df20 Revert "Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. (#159197)"
This reverts commit e444cd24d48b3a46f067974f2cc157f5ed27709f.

Reverted https://github.com/pytorch/pytorch/pull/159197 on behalf of https://github.com/laithsakka due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/159197#issuecomment-3195436668))
2025-08-18 07:22:13 +00:00
e444cd24d4 Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. (#159197)
This might cause some new DDEs at call sites that do not use is_contiguous_or_false() or sym_is_contiguous(),
but we want to find those call sites and handle them properly by calling is_contiguous_or_false() instead of is_contiguous() explicitly where appropriate.
I had to fix one issue after removing the implicit size-oblivious reasoning; here is the context.

In https://github.com/pytorch/pytorch/pull/157472, we defined sym_is_contiguous to be the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression that represents contiguity and is guaranteed not to throw a DDE.

When is_contiguous is called, we do sym_is_contiguous().guard_bool();
when is_contiguous_or_false is called, we do sym_is_contiguous().guard_or_false().

One path that was not handled well was this:
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
Namely, if we call sym_is_contiguous_custom and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, then we used to call is_contiguous(this, memory_format).

This used to go through load_pyobj_interpreter and end up calling the Python is_contiguous, which used implicit size-oblivious reasoning.
Once we removed that implicit size-oblivious reasoning, the right thing to call is
pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get a DDE even when the caller is using sym_is_contiguous.

So I had to define it for the PyInterpreter, and then override it for nested tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159197
Approved by: https://github.com/ezyang
2025-08-16 09:15:58 +00:00
50580b5053 Add minimal nn.functional.log_softmax support for NestedTensor (#159662)
This only works for the jagged layout and for the non-batch and non-jagged dimensions.

I did this mostly by copy-pasting from the existing softmax implementation, but it seems fairly straightforward and I think it should work.
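
A rough usage sketch for the supported case (softmax over the trailing, non-jagged dim):

```python
import torch
import torch.nn.functional as F

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)  # shape (B, j1, 4)

out = F.log_softmax(nt, dim=-1)  # reduces over the last dim of each component
```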
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159662
Approved by: https://github.com/jbschlosser
2025-08-06 20:34:02 +00:00
ed03492238 Add check nested_tensor_from_jagged param jagged_dim >= 1 (#157770)
Fixes #157404
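
A rough sketch of the new validation (the exact error type/message is an assumption, not taken from the PR):

```python
import torch

values = torch.randn(7, 4)
offsets = torch.tensor([0, 3, 7])

nt = torch.nested.nested_tensor_from_jagged(values, offsets)  # jagged_dim defaults to 1

try:
    torch.nested.nested_tensor_from_jagged(values, offsets, jagged_dim=0)
except Exception as e:  # now rejected up front instead of producing a malformed NJT
    print(e)
```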

## Test Result

```bash
pytest test/test_nestedtensor.py

...............................................s..........ssssss.................................................................................................s.s..sssss..s...ss............................................................. [ 44%]
...........................................................sssss....sss...s.........ss....s....sss.........s.sss...s..s......s............s.sss.ss...............s.....................s....s......................s.s.....s....s..s..ssssssssss [ 59%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..ssssss.ssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.ssssssss...............................s........................................... [ 74%]
.......sss...................................................................................................................................................................................................................................... [ 89%]
....sss..........................................................................................................................................................                                                                                [100%]

==================================================================================================== 1317 passed, 258 skipped in 2504.27s (0:41:44) ====================================================================================================
```

![image](https://github.com/user-attachments/assets/dcc8e46d-b88f-4580-b4ad-0999bad33ec9)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157770
Approved by: https://github.com/soulitzer

Co-authored-by: Jeffrey Wan <soulitzer@gmail.com>
2025-07-10 00:34:39 +00:00
db259bd6b8 [BE][12/16] fix typos in torch/ (#156602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156602
Approved by: https://github.com/justinchuby, https://github.com/albanD
ghstack dependencies: #156318, #156320
2025-07-02 22:55:29 +00:00
e95e8eed0a mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-06-14 18:18:43 +00:00
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
0a7eef140b Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022)
Fixes #153790

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154022
Approved by: https://github.com/Skylion007
2025-05-27 14:10:00 +00:00
2362bd4a4c [Torch][NT] Fix NestedTensor contiguous check condition. (#153237) (#153529)
Fixes #153237

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153529
Approved by: https://github.com/jbschlosser
2025-05-14 17:15:48 +00:00
2c4bc65366 [aotd] Guess tangents stride as output strides (#144579)
AOTDispatch, while preparing the AOT backward graph, does not know the real tangents that the user will pass when running backward.

AOTD guesses the tangents. Previously, we guessed that the memory format of the tangents would match the memory format of the corresponding outputs, and if the tangents specified at runtime did not match the memory format guessed during compilation, AOTD coerced (copied) them to the guessed memory_format.

But as Horace found, there are popular use cases where the outputs of the compiled region are in a specific memory_format, e.g. a 4D tensor with dims 1 and 2 transposed.

https://github.com/karpathy/nanoGPT/blob/master/model.py#L57

This PR changes the logic so that AOTD expects tangents with the same strides as the outputs. As a result, it avoids the coercion for the transposed-dims case; see the sketch below.
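
A rough illustration, assuming a compiled region whose output is a transposed 4D tensor (shapes are made up):

```python
import torch


def f(x):
    # Output is non-contiguous: dims 1 and 2 are transposed, as in the attention example.
    return (x * 2).transpose(1, 2)


x = torch.randn(2, 4, 8, 16, requires_grad=True)
out = torch.compile(f)(x)

# randn_like uses preserve_format, so the tangent has the same strides as the output;
# with this change AOTDispatch can consume it directly instead of copying it
# to a guessed memory format.
grad = torch.randn_like(out)
out.backward(grad)
```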

Limitations:
We keep guessing memory_format for:
1/ Dynamic shapes (needs more changes)
2/ Tensor subclasses (needs more changes)

Other changes:
test_torchinductor was always creating contiguous tangents via `torch.randn()`; they are changed to `torch.randn_like()` so that the computation is compared with the same strides.

(E.g. for CUDA float16, strides affect numerics for FFT ops.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144579
Approved by: https://github.com/bdhirsh
2025-03-20 15:41:36 +00:00
93e9daed54 [cuDNN][SDPA][Nested Tensor] Experimental cuDNN Nested Tensor SDPA Support (forward only) (#141178)
Disabled by default for now behind `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1`

Just wanted to get this out before starting a series of SDPA cleanup PRs---the biggest thing is we don't need the boilerplate around all of the `build_graph_and_tensors*` functions anymore as we can now use the `UID`-style referencing of tensor nodes as was done for the Conv-V8 API backend.
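
A rough sketch of how one might opt in (shapes, dtype, and the exact gating behavior are assumptions, not taken from this PR):

```python
import os

os.environ["TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED"] = "1"  # opt in; disabled by default

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Components of shape (seq_len, n_heads, head_dim), transposed to the (B, H, j1, D) layout.
nt = torch.nested.nested_tensor(
    [torch.randn(5, 8, 64), torch.randn(7, 8, 64)],
    layout=torch.jagged, device="cuda", dtype=torch.float16,
)
q = k = v = nt.transpose(1, 2)

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)  # forward only for now
```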

CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141178
Approved by: https://github.com/jbschlosser
2025-03-04 23:09:09 +00:00
fa8e3a28a7 Revert "[cuDNN][SDPA][Nested Tensor] Experimental cuDNN Nested Tensor SDPA Support (forward only) (#141178)"
This reverts commit 533b884870acd951e684e0bf551eb76904dec047.

Reverted https://github.com/pytorch/pytorch/pull/141178 on behalf of https://github.com/jeanschmidt due to Broke internal arvr signals, see D69971019. @jbschlosser please help the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/141178#issuecomment-2676317470))
2025-02-22 17:28:12 +00:00
533b884870 [cuDNN][SDPA][Nested Tensor] Experimental cuDNN Nested Tensor SDPA Support (forward only) (#141178)
Disabled by default for now behind `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1`

Just wanted to get this out before starting a series of SDPA cleanup PRs---the biggest thing is we don't need the boilerplate around all of the `build_graph_and_tensors*` functions anymore as we can now use the `UID`-style referencing of tensor nodes as was done for the Conv-V8 API backend.

CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141178
Approved by: https://github.com/jbschlosser
2025-02-21 05:22:19 +00:00
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
43496e9b90 [NJT] fix flop counter for SDPA & test (#147032)
Fixes 3 issues (a usage sketch follows the list):
1. The test wasn't actually testing SDPA: both were checking cuda, and the inputs to SDPA were not transposed.
2. FlopCounterMode has been renamed _FlopCounterMode (and a wrapper named FlopCounterMode has been added)
3. offsets_to_list also needs to ignore the actual offset values if offsets is a meta tensor.
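
A rough sketch of what the fixed test now exercises (shapes and device handling are assumptions):

```python
import torch
import torch.nn.functional as F
from torch.utils.flop_counter import FlopCounterMode

nt = torch.nested.nested_tensor(
    [torch.randn(5, 4, 16), torch.randn(7, 4, 16)], layout=torch.jagged
)
q = k = v = nt.transpose(1, 2)  # (B, H, j1, D): transposed inputs, per issue 1

with FlopCounterMode(display=False) as flop_counter:
    F.scaled_dot_product_attention(q, k, v)
print(flop_counter.get_total_flops())
```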

Differential Revision: [D69558785](https://our.internmc.facebook.com/intern/diff/D69558785)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147032
Approved by: https://github.com/jbschlosser
2025-02-13 07:14:58 +00:00
3cadce7af2 [NJT] Fix inference mode for composite implicit ops without nested-specific kernel (#146633)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146633
Approved by: https://github.com/jbschlosser
2025-02-10 16:59:48 +00:00
e01a5e9e1e Small improvements to NJT matrix multiplies (#146405)
Fixes #146404

Adds changes to the matmul and matmul_backward operations for nested jagged tensors to support backpropagation when the output is a regular strided tensor.
This required adding support for the nested matmul operation to work when the nested tensor isn't 'self', i.e.
`A@B` where `A` isn't nested but `B` is.

The operation schemas had to be updated to reflect that either input can be a strided tensor instead (and the gradient), so an extra assertion is added in an edge case where neither input is nested.

Unit tests are also added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146405
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
2025-02-06 04:51:12 +00:00
1ba1b7b597 Support remaining *_like factory functions for NJT (#144889)
Fixes #144761

This PR adds NJT impls for those *_like functions that were previously missing:
* `full_like()`
* `rand_like()`
* `randint_like()`

It also fixes a bug in existing *_like functions when a new device is specified. The fix is to also transfer `offsets` / `lengths` to the new device.
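
A rough usage sketch of the newly covered functions:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)

ones = torch.full_like(nt, 1.0)       # newly supported
noise = torch.rand_like(nt)           # newly supported
ints = torch.randint_like(nt, 0, 10)  # newly supported

# With the device fix, offsets/lengths follow the requested device as well, e.g.:
# nt_cuda = torch.zeros_like(nt, device="cuda")
```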
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144889
Approved by: https://github.com/soulitzer
2025-01-27 21:33:51 +00:00
b2a0feac85 Update OSS nested tensor docs to focus on NJT (#145402)
Updated nested tensor docs to be NJT-centric (instead of NST-centric). They now include:
* High-level description of NST vs. NJT + a recommendation to use NJT
* General NJT construction / usage
* torch.compile() integration w/ dynamic shapes
* Common errors and how to fix them
* Contribution guide
* Data layout / shape information (with diagram)
* Links to more extensive tutorials involving Transformers / SDPA / FlexAttention

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145402
Approved by: https://github.com/soulitzer
2025-01-25 04:08:19 +00:00
5725462cd8 Update NJT linear_backward to return non-aliased tensor bias grad (#145399)
Fixes https://github.com/pytorch/pytorch/issues/141292

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145399
Approved by: https://github.com/jbschlosser
ghstack dependencies: #145520, #145531, #145533
2025-01-25 00:58:04 +00:00
128f3627b1 Implement backward for NJT matmul (#144587)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

This PR implements missing backward support for NJT matmul. Notably, for dense tensors, matmul dispatches to bmm. However, due to historical reasons related to NST, NJT handles matmul directly, and thus can't rely on the CompositeImplicit impl of matmul to get the derivative formula.
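
A rough usage sketch (an NJT multiplied by a dense weight, then backward):

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)],
    layout=torch.jagged, requires_grad=True,
)
w = torch.randn(8, 4, requires_grad=True)

out = nt @ w                    # NJT of shape (B, j1, 4)
out.values().sum().backward()   # now produces grads for both nt and w
print(w.grad.shape)             # torch.Size([8, 4])
```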

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144587
Approved by: https://github.com/soulitzer
ghstack dependencies: #144586
2025-01-21 18:27:50 +00:00
af204135d8 Fix NJT fill.Scalar for contiguous inputs (#144586)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

This PR implements the missing `fill.Scalar` support, which works fine for contiguous inputs, but there is still some AOTAutograd debugging required to handle non-contiguous transposed NJTs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144586
Approved by: https://github.com/soulitzer
2025-01-21 18:22:08 +00:00
805c4b597a PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145202
Approved by: https://github.com/bobrenjc93
2025-01-20 22:37:26 +00:00
b63b81410c Fix NJT frexp() to handle both outputs (#144585)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Before this PR, `frexp()` for NJT was handled via the unary pointwise fallback. The op returns a tuple, however, and the fallback doesn't handle that. This PR defines an explicit impl for `frexp()` that wraps both returned `(mantissa, exponent)` as NJTs.
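
A rough usage sketch:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)

mantissa, exponent = torch.frexp(nt)
print(mantissa.is_nested, exponent.is_nested)  # both outputs are NJTs
```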
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144585
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583, #144584
2025-01-18 15:59:56 +00:00
3ee531f8b9 Support NJT chunk() backward on batch dim (#144584)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Implements `chunk()` backward on the batch dim, which was left out before. This PR unbinds the components and invokes `copy_()` on these to pass along the appropriate gradients.
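
A rough usage sketch of batch-dim chunking with gradients flowing back to the input:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4)],
    layout=torch.jagged, requires_grad=True,
)

a, b = nt.chunk(2, dim=0)  # chunk over the batch dim
(a.values().sum() + b.values().sum()).backward()
print(nt.grad is not None)  # True
```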
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144584
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583
2025-01-18 15:58:24 +00:00
a8ef423fed Fix NJT min / max backward() for non-ragged reductions (#144583)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

`value_selecting_reduction_backward()` is used in the backward for min / max, so this PR implements it for NJT. Notably, this isn't enough for reducing over the ragged dim, since that results in a dense tensor and thus NJT's torch_dispatch will not be called for this op. We need factory function support for nested ints to fix that case.
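
A rough usage sketch of the non-ragged case this enables (assuming max over the trailing dim is already supported in the forward):

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)],
    layout=torch.jagged, requires_grad=True,
)

out = nt.max(dim=-1).values     # NJT of shape (B, j1); reduction is over a non-ragged dim
out.values().sum().backward()   # backward now works for this case
```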
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144583
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582
2025-01-17 20:57:11 +00:00
5c4545f857 [BE][Easy] enable PYFMT for torch/[a-s]*/ (#138447)
Reproduce command:

```bash
ghstack checkout https://github.com/pytorch/pytorch/pull/138447
git checkout HEAD~1 torch/
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138447
Approved by: https://github.com/ezyang
2024-12-23 14:04:00 +00:00
f1cbf4b1b5 Enable ruff's unused variable checking everywhere in pytorch (#136965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136965
Approved by: https://github.com/cyyever, https://github.com/albanD
2024-12-22 02:33:11 +00:00
3f99682fbd NJT linear_backward should not return inner tensor as-is (#143333)
Fixes debug=1 use-count checks https://github.com/pytorch/pytorch/actions/runs/12187808902/job/34002323481#step:22:2521

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143333
Approved by: https://github.com/jbschlosser
2024-12-18 00:15:18 +00:00
661d1f0372 [aotd] non-contiguous NestedTensor mutation in compile (#139630)
Allow mutations for subclasses that are non-contiguous.

Changes:

Removing assert in collect_metadata_analysis

Main requested testcase:
Compilation of NJT.index_put()

Adding a test in test_nestedtensor.py that compiles NJT.index_put()

It is decomposed into NJT split/unbind, which needed additional `torch._check` / `torch._check_is_size` calls for NJT.unbind() and guard_size_oblivious() usage in _meta_registrations and _inductor/lowering.py.

Special case:
If a tangent is mutated outside of the graph, it does not participate in the backward graph; autograd in this case sets the tangent to a zeros tensor.

We handle it separately in CompiledFunction.backward: we do no processing for this tangent and broadcast it to the number of expected unwrapped subclass arguments.

Disabling 2 tests for dynamo:
1/ For nested tensor: a symbolic-shapes issue on the nested_tensor index operation that does splits [0, 0, 0] - there is a failure with "pending unbacked symints". This PR does not add more .tolist()/.item() ops than there were before.

2/ As we no longer fail with an exception in collect_metadata_analysis, new paths for dynamo started working, and it started failing with something strange: the set_ handling in storage_offset (because of the test for views) updates storage "cpu" -> "meta".

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139630
Approved by: https://github.com/bdhirsh
2024-12-06 12:18:46 +00:00
e803a3d83a Fix reductions for NJTs with ragged_idx != 1 (#142173)
**Background:** conversion from outer dim -> inner dim makes the (previously valid) assumption that the ragged dim is immediately next to the batch dim. This is no longer the case after #137125.

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual ragged_idx. Since ragged_idx tells us where the packed ragged / batch dim is, both ragged and batch outer dims should map to this inner dim. The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass this.
* Fixes outputs across keepdim settings when reducing over ragged / batch dims.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
2024-12-06 01:23:17 +00:00
161a2340ee Switch to using Python nested int (#141166)
Doesn't seem to noticeably slow down eager - TestNestedTensorSubclass tests with and without the PR finished in similar amounts of time (around 57s, 58s)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141166
Approved by: https://github.com/ezyang
2024-12-02 19:17:30 +00:00
c9e2b3fefe NJT: Return correct number of outputs for chunk() on the batch dim (#141604)
Old logic was completely wrong, returning `chunk_size` chunks instead of the intended number. The original test didn't catch this because `chunk_size == num_chunks` :p New OpInfo-based testing covers it though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141604
Approved by: https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392, #141506
2024-11-27 02:31:23 +00:00
43121b6f0d Adjust output NJT ragged_idx for reductions and select() (#141506)
This fixes some bugs when performing reductions / select() on dims before the ragged dim. In this case, the output NJT has a smaller number of dims, and its ragged_idx should reflect that correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141506
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392
2024-11-27 02:25:53 +00:00
23793cf93d NJT unsqueeze() fixes (#141392)
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
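
A rough usage sketch of the cases covered above:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)  # shape (B, j1, 4)

before_ragged = nt.unsqueeze(1)  # (B, 1, j1, 4): newly allowed, adjusts _ragged_idx
after_last = nt.unsqueeze(-1)    # (B, j1, 4, 1): dim=-1 canonicalized wrt inp.dim() + 1
```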
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736, #140161
2024-11-26 22:38:35 +00:00
9ee5d6f83c Initial NJT testing over dim type / views (#140161)
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) has been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming `nt._ragged_idx == 1` and allows a batch dim as a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736
2024-11-26 22:08:08 +00:00
869d629c0f Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500
2024-11-26 21:19:58 +00:00