mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add NestedTensor dispatch for _is_any_true/_is_all_true (#162096)
Fixes: https://github.com/pytorch/pytorch/issues/161818 ### Summary Add NestedTensor support for `_is_any_true` and `_is_all_true`. ### Changes - Register dispatch for `aten._is_any_true.default` and `aten._is_all_true.default` - Add CPU tests: - `test_is_any_true_jagged`: dispatch_matches_values_buffer, all_false_returns_false, one_true_returns_true - `test_is_all_true_jagged`: dispatch_matches_values_buffer, all_true_returns_true, any_false_returns_false ### Testing Before Fix: `pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v` Output: ``` FAILED [0.0129s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu - NotImplementedError: aten._is_all_true.default FAILED [0.0007s] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu - NotImplementedError: aten._is_any_true.default ``` After Fix: `pytest -q test/test_nestedtensor.py -k "test_is_any_true_jagged or test_is_all_true_jagged" -v` Output: ``` Running 2 items in this shard test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_all_true_jagged_cpu PASSED [0.0277s] [ 50%] test/test_nestedtensor.py::TestNestedTensorDeviceTypeCPU::test_is_any_true_jagged_cpu PASSED [0.0013s] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/162096 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
d0086708dd
commit
fd785b1762
@ -1334,6 +1334,82 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
|
||||
lambda: func(nt_noncontiguous),
|
||||
)
|
||||
|
||||
def test_is_any_true_jagged(self, device):
|
||||
B, Fin = 2, 6
|
||||
start = torch.zeros(B, dtype=torch.int64, device=device)
|
||||
lengths = torch.tensor([3, 2], dtype=torch.int64, device=device)
|
||||
|
||||
# NestedTensor reduction should operate on same data as .values().
|
||||
with self.subTest("dispatch_matches_values_buffer"):
|
||||
cond = torch.tensor(
|
||||
[
|
||||
[True, False, False, True, True, False],
|
||||
[False, False, True, False, False, False],
|
||||
],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
nt = torch.nested.narrow(
|
||||
cond, dim=1, start=start, length=lengths, layout=torch.jagged
|
||||
)
|
||||
out_nt = torch.ops.aten._is_any_true.default(nt).item()
|
||||
out_vals = torch.ops.aten._is_any_true.default(nt.values()).item()
|
||||
self.assertEqual(out_nt, out_vals)
|
||||
|
||||
# Verify jagged boolean behavior.
|
||||
with self.subTest("all_false_returns_false"):
|
||||
cond_false = torch.zeros(B, Fin, dtype=torch.bool, device=device)
|
||||
nt_false = torch.nested.narrow(
|
||||
cond_false, dim=1, start=start, length=lengths, layout=torch.jagged
|
||||
)
|
||||
self.assertFalse(torch.ops.aten._is_any_true.default(nt_false).item())
|
||||
|
||||
with self.subTest("one_true_returns_true"):
|
||||
cond_mixed = torch.zeros(B, Fin, dtype=torch.bool, device=device)
|
||||
cond_mixed[0, 0] = True
|
||||
nt_mixed = torch.nested.narrow(
|
||||
cond_mixed, dim=1, start=start, length=lengths, layout=torch.jagged
|
||||
)
|
||||
self.assertTrue(torch.ops.aten._is_any_true.default(nt_mixed).item())
|
||||
|
||||
def test_is_all_true_jagged(self, device):
|
||||
B, Fin = 2, 6
|
||||
start = torch.zeros(B, dtype=torch.int64, device=device)
|
||||
lengths = torch.tensor([3, 2], dtype=torch.int64, device=device)
|
||||
|
||||
# NestedTensor reduction should operate on same data as .values().
|
||||
with self.subTest("dispatch_matches_values_buffer"):
|
||||
cond = torch.tensor(
|
||||
[
|
||||
[True, True, True, False, False, False],
|
||||
[True, True, False, False, False, False],
|
||||
],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
nt = torch.nested.narrow(
|
||||
cond, dim=1, start=start, length=lengths, layout=torch.jagged
|
||||
)
|
||||
out_nt = torch.ops.aten._is_all_true.default(nt).item()
|
||||
out_vals = torch.ops.aten._is_all_true.default(nt.values()).item()
|
||||
self.assertEqual(out_nt, out_vals)
|
||||
|
||||
# Verify jagged boolean behavior.
|
||||
with self.subTest("all_true_returns_true"):
|
||||
cond_true = torch.ones(B, Fin, dtype=torch.bool, device=device)
|
||||
nt_true = torch.nested.narrow(
|
||||
cond_true, dim=1, start=start, length=lengths, layout=torch.jagged
|
||||
)
|
||||
self.assertTrue(torch.ops.aten._is_all_true.default(nt_true).item())
|
||||
|
||||
with self.subTest("any_false_returns_false"):
|
||||
cond_mixed = torch.ones(B, Fin, dtype=torch.bool, device=device)
|
||||
cond_mixed[0, 1] = False
|
||||
nt_mixed = torch.nested.narrow(
|
||||
cond_mixed, dim=1, start=start, length=lengths, layout=torch.jagged
|
||||
)
|
||||
self.assertFalse(torch.ops.aten._is_all_true.default(nt_mixed).item())
|
||||
|
||||
@parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
|
||||
def test_binary_ops_with_scalar(self, device, func):
|
||||
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
|
||||
|
@ -2131,6 +2131,19 @@ def all_any_max_min_default(func, *args, **kwargs):
|
||||
return func(inp._values, **new_kwargs)
|
||||
|
||||
|
||||
@register_jagged_func(
|
||||
[torch.ops.aten._is_all_true.default, torch.ops.aten._is_any_true.default],
|
||||
"self: jt_all",
|
||||
)
|
||||
def _is_true_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
inp = new_kwargs.pop("input")
|
||||
return func(inp._values)
|
||||
|
||||
|
||||
@register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?")
|
||||
def min_dim(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
|
Reference in New Issue
Block a user