change sparse COO comparison strategy in assert_close (#68728)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68728

This removes the ability for `assert_close` to `.coalesce()` the tensors internally. Additionally, we now also check `.sparse_dim()`. Sparse team: please make sure that is the behavior you want for all sparse COO comparisons in the future. #67796 will temporarily keep BC by always coalescing, but in the future `TestCase.assertEqual` will no longer do that.

cc nikitaved pearu cpuhrsch IvanYashchuk

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D33542996

Pulled By: mruberry

fbshipit-source-id: a8d2322c6ee1ca424e3efb14ab21787328cf28fc
This commit is contained in:
Philip Meier
2022-01-12 06:40:45 -08:00
committed by Facebook GitHub Bot
parent 8d05174def
commit 802dd2b725
3 changed files with 36 additions and 52 deletions

View File

@ -1204,37 +1204,15 @@ class TestAssertCloseSparseCOO(TestCase):
for fn in assert_close_with_inputs(actual, expected):
fn()
def test_mismatching_is_coalesced(self):
indices = (
(0, 1),
(1, 0),
)
values = (1, 2)
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
expected = actual.clone().coalesce()
def test_mismatching_sparse_dims(self):
t = torch.randn(2, 3, 4)
actual = t.to_sparse()
expected = t.to_sparse(2)
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "is_coalesced"):
with self.assertRaisesRegex(AssertionError, re.escape("number of sparse dimensions in sparse COO tensors")):
fn()
def test_mismatching_is_coalesced_no_check(self):
actual_indices = (
(0, 1),
(1, 0),
)
actual_values = (1, 2)
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce()
expected_indices = (
(0, 1, 1,),
(1, 0, 0,),
)
expected_values = (1, 1, 1)
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
fn(check_is_coalesced=False)
def test_mismatching_nnz(self):
actual_indices = (
(0, 1),

View File

@ -50,6 +50,14 @@ _DTYPE_PRECISIONS = {
torch.complex64: (1.3e-6, 1e-5),
torch.complex128: (1e-7, 1e-7),
}
# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
_DTYPE_PRECISIONS.update(
{
dtype: _DTYPE_PRECISIONS[torch.float32]
for dtype in (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32)
}
)
def default_tolerances(*inputs: Union[torch.Tensor, torch.dtype]) -> Tuple[float, float]:
@ -622,13 +630,12 @@ class TensorLikePair(Pair):
- the :attr:`~torch.Tensor.shape`,
- whether both inputs are quantized or not,
- and if they are the quantization scheme.
- and if they use the same quantization scheme.
Checks for
- :attr:`~torch.Tensor.layout`,
- :meth:`~torch.Tensor.stride`,
- :meth:`~torch.Tensor.is_coalesced`,
- :attr:`~torch.Tensor.device`, and
- :attr:`~torch.Tensor.dtype`
@ -652,15 +659,8 @@ class TensorLikePair(Pair):
if actual.layout != expected.layout:
if self.check_layout:
raise_mismatch_error("layout", actual.layout, expected.layout)
else:
if actual.layout == torch.strided and self.check_stride and actual.stride() != expected.stride():
raise_mismatch_error("stride()", actual.stride(), expected.stride())
elif (
actual.layout == torch.sparse_coo
and self.check_is_coalesced
and actual.is_coalesced() != expected.is_coalesced()
):
raise_mismatch_error("is_coalesced()", actual.is_coalesced(), expected.is_coalesced())
elif actual.layout == torch.strided and self.check_stride and actual.stride() != expected.stride():
raise_mismatch_error("stride()", actual.stride(), expected.stride())
if self.check_device and actual.device != expected.device:
raise_mismatch_error("device", actual.device, expected.device)
@ -677,7 +677,6 @@ class TensorLikePair(Pair):
- ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
:func:`torch.promote_types`).
- ... not of the same ``layout``, they are converted to strided tensors.
- ... both sparse COO tensors but only one is coalesced, the other one is coalesced.
Args:
actual (Tensor): Actual tensor.
@ -699,9 +698,6 @@ class TensorLikePair(Pair):
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
actual = actual.to_dense() if actual.layout != torch.strided else actual
expected = expected.to_dense() if expected.layout != torch.strided else expected
elif actual.is_sparse and actual.is_coalesced() != expected.is_coalesced():
actual = actual.coalesce()
expected = expected.coalesce()
return actual, expected
@ -735,10 +731,20 @@ class TensorLikePair(Pair):
) -> None:
"""Compares sparse COO tensors by comparing
- the number of sparse dimensions,
- the number of non-zero elements (nnz) for equality,
- the indices for equality, and
- the values for closeness.
"""
if actual.sparse_dim() != expected.sparse_dim():
raise self._make_error_meta(
AssertionError,
(
f"The number of sparse dimensions in sparse COO tensors does not match: "
f"{actual.sparse_dim()} != {expected.sparse_dim()}"
),
)
if actual._nnz() != expected._nnz():
raise self._make_error_meta(
AssertionError,
@ -1031,7 +1037,6 @@ def assert_close(
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
check_is_coalesced: bool = True,
msg: Optional[str] = None,
):
r"""Asserts that ``actual`` and ``expected`` are close.
@ -1050,8 +1055,6 @@ def assert_close(
If ``actual`` and ``expected`` are sparse (either having COO or CSR layout), their strided members are
checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout,
are always checked for equality whereas the values are checked for closeness according to the definition above.
Sparse COO tensors are only considered close if both are either coalesced or uncoalesced (if
``check_is_coalesced`` is ``True``).
If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
:meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
@ -1089,9 +1092,6 @@ def assert_close(
check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
compared.
check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
check_is_coalesced (bool): If ``True`` (default) and corresponding tensors are sparse COO, checks that both
``actual`` and ``expected`` are either coalesced or uncoalesced. If this check is disabled, tensors are
:meth:`~torch.Tensor.coalesce`'ed before being compared.
msg (Optional[str]): Optional error message to use in case a failure occurs during the comparison.
Raises:
@ -1112,8 +1112,6 @@ def assert_close(
:attr:`~torch.Tensor.device`.
AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
AssertionError: If ``check_is_coalesced`` is ``True``, but corresponding sparse COO tensors are not both
either coalesced or uncoalesced.
AssertionError: If the values of corresponding tensors are not close according to the definition above.
The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
@ -1136,6 +1134,16 @@ def assert_close(
+---------------------------+------------+----------+
| :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| other | ``0.0`` | ``0.0`` |
+---------------------------+------------+----------+
@ -1255,6 +1263,5 @@ def assert_close(
check_dtype=check_dtype,
check_layout=check_layout,
check_stride=check_stride,
check_is_coalesced=check_is_coalesced,
msg=msg,
)

View File

@ -85,7 +85,6 @@ def assert_allclose(
check_device=True,
check_dtype=False,
check_stride=False,
check_is_coalesced=False,
msg=msg or None,
)