Strict shape checking for NJTs with TestCase.assertEqual() (#131898)

**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131898
Approved by: https://github.com/soulitzer
This commit is contained in:
Joel Schlosser
2024-07-26 14:39:24 -04:00
committed by PyTorch MergeBot
parent 58f76bc301
commit d53b11bb6e
4 changed files with 43 additions and 14 deletions

View File

@ -3765,11 +3765,10 @@ class TestCase(expecttest.TestCase):
elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
x = torch.as_tensor(x, dtype=y.dtype, device=y.device)
# If x or y are tensors and nested then we unbind them to a list of tensors this should allow us to compare
# a nested tensor to a nested tensor and a nested tensor to a list of expected tensors
if isinstance(x, torch.Tensor) and x.is_nested:
# unbind NSTs to compare them; don't do this for NJTs
if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided:
x = x.unbind()
if isinstance(y, torch.Tensor) and y.is_nested:
if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided:
y = y.unbind()
error_metas = not_close_error_metas(