Add NestedTensor dispatch for _is_any_true/_is_all_true (#162096)

Fixes: https://github.com/pytorch/pytorch/issues/161818

### Summary
Add NestedTensor support for `_is_any_true` and `_is_all_true`.

### Changes
- Register dispatch for `aten._is_any_true.default` and
  `aten._is_all_true.default`
- Add CPU tests:
  - `test_is_any_true_jagged`: dispatch_matches_values_buffer,
    all_false_returns_false, one_true_returns_true
  - `test_is_all_true_jagged`: dispatch_matches_values_buffer,
    all_true_returns_true, any_false_returns_false

### Testing

Before Fix:

`pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v`

Output:
```
FAILED [0.0129s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu - NotImplementedError: aten._is_all_true.default
FAILED [0.0007s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu - NotImplementedError: aten._is_any_true.default
```

After Fix:

`pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v`

Output:

```
Running 2 items in this shard

test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu PASSED [0.0277s]                                                                                                                               [ 50%]
test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu PASSED [0.0013s]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162096
Approved by: https://github.com/jbschlosser
This commit is contained in:
adabeyta
2025-09-22 20:22:40 +00:00
committed by PyTorch MergeBot
parent d0086708dd
commit fd785b1762
2 changed files with 89 additions and 0 deletions

View File

@@ -1334,6 +1334,82 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
lambda: func(nt_noncontiguous),
)
def test_is_any_true_jagged(self, device):
    """_is_any_true on a jagged NestedTensor must agree with the same
    reduction over its packed values() buffer, and respect the jagged
    per-row lengths."""
    batch, feat = 2, 6
    starts = torch.zeros(batch, dtype=torch.int64, device=device)
    row_lengths = torch.tensor([3, 2], dtype=torch.int64, device=device)
    any_true = torch.ops.aten._is_any_true.default

    def as_jagged(dense):
        # Jagged-layout view keeping the first row_lengths[i] columns of row i.
        return torch.nested.narrow(
            dense, dim=1, start=starts, length=row_lengths, layout=torch.jagged
        )

    # NestedTensor reduction should operate on same data as .values().
    with self.subTest("dispatch_matches_values_buffer"):
        dense = torch.tensor(
            [
                [True, False, False, True, True, False],
                [False, False, True, False, False, False],
            ],
            dtype=torch.bool,
            device=device,
        )
        nt = as_jagged(dense)
        self.assertEqual(any_true(nt).item(), any_true(nt.values()).item())

    # Verify jagged boolean behavior: all-False buffer -> False.
    with self.subTest("all_false_returns_false"):
        nt_false = as_jagged(torch.zeros(batch, feat, dtype=torch.bool, device=device))
        self.assertFalse(any_true(nt_false).item())

    # A single True inside the jagged region -> True.
    with self.subTest("one_true_returns_true"):
        dense_mixed = torch.zeros(batch, feat, dtype=torch.bool, device=device)
        dense_mixed[0, 0] = True
        self.assertTrue(any_true(as_jagged(dense_mixed)).item())
def test_is_all_true_jagged(self, device):
    """_is_all_true on a jagged NestedTensor must agree with the same
    reduction over its packed values() buffer, and respect the jagged
    per-row lengths."""
    batch, feat = 2, 6
    starts = torch.zeros(batch, dtype=torch.int64, device=device)
    row_lengths = torch.tensor([3, 2], dtype=torch.int64, device=device)
    all_true = torch.ops.aten._is_all_true.default

    def as_jagged(dense):
        # Jagged-layout view keeping the first row_lengths[i] columns of row i.
        return torch.nested.narrow(
            dense, dim=1, start=starts, length=row_lengths, layout=torch.jagged
        )

    # NestedTensor reduction should operate on same data as .values().
    with self.subTest("dispatch_matches_values_buffer"):
        dense = torch.tensor(
            [
                [True, True, True, False, False, False],
                [True, True, False, False, False, False],
            ],
            dtype=torch.bool,
            device=device,
        )
        nt = as_jagged(dense)
        self.assertEqual(all_true(nt).item(), all_true(nt.values()).item())

    # Verify jagged boolean behavior: all-True buffer -> True.
    with self.subTest("all_true_returns_true"):
        nt_true = as_jagged(torch.ones(batch, feat, dtype=torch.bool, device=device))
        self.assertTrue(all_true(nt_true).item())

    # A single False inside the jagged region -> False.
    with self.subTest("any_false_returns_false"):
        dense_mixed = torch.ones(batch, feat, dtype=torch.bool, device=device)
        dense_mixed[0, 1] = False
        self.assertFalse(all_true(as_jagged(dense_mixed)).item())
@parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
def test_binary_ops_with_scalar(self, device, func):
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(

View File

@@ -2131,6 +2131,19 @@ def all_any_max_min_default(func, *args, **kwargs):
return func(inp._values, **new_kwargs)
@register_jagged_func(
    [torch.ops.aten._is_all_true.default, torch.ops.aten._is_any_true.default],
    "self: jt_all",
)
def _is_true_default(func, *args, **kwargs):
    """Dispatch aten._is_all_true / aten._is_any_true for jagged NestedTensors.

    Both ops reduce the whole tensor and take no extra arguments, so
    reducing over the packed values buffer is equivalent to reducing over
    the jagged tensor itself.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    jagged_input = new_kwargs.pop("input")
    # Forward the reduction to the underlying contiguous values buffer.
    return func(jagged_input._values)
@register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?")
def min_dim(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]