mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Strict shape checking for NJTs with TestCase.assertEqual() (#131898)
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal. This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal. Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131898 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
58f76bc301
commit
d53b11bb6e
@ -3765,11 +3765,10 @@ class TestCase(expecttest.TestCase):
|
||||
elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
|
||||
x = torch.as_tensor(x, dtype=y.dtype, device=y.device)
|
||||
|
||||
# If x or y are tensors and nested then we unbind them to a list of tensors this should allow us to compare
|
||||
# a nested tensor to a nested tensor and a nested tensor to a list of expected tensors
|
||||
if isinstance(x, torch.Tensor) and x.is_nested:
|
||||
# unbind NSTs to compare them; don't do this for NJTs
|
||||
if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided:
|
||||
x = x.unbind()
|
||||
if isinstance(y, torch.Tensor) and y.is_nested:
|
||||
if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided:
|
||||
y = y.unbind()
|
||||
|
||||
error_metas = not_close_error_metas(
|
||||
|
Reference in New Issue
Block a user