add support for sparse tensors in torch.testing.assert_close (#58844)

Summary:
This adds support for sparse tensors the same way `torch.testing._internal.common_utils.TestCase.assertEqual` does:

5c7dace309/torch/testing/_internal/common_utils.py (L1287-L1313)

- Tensors are coalesced before comparison.
- Indices and values are compared individually.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58844

Reviewed By: zou3519

Differential Revision: D29160250

Pulled By: mruberry

fbshipit-source-id: b0955656c2c7ff3db37a1367427ca54ca14f2e87
This commit is contained in:
Philip Meier
2021-06-23 21:57:51 -07:00
committed by Facebook GitHub Bot
parent 80f40b172f
commit 6ea22672c4
2 changed files with 319 additions and 27 deletions

View File

@ -11,7 +11,7 @@ from typing import Any, Callable, Iterator, List, Tuple
import torch
from torch.testing._internal.common_utils import \
(IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
(IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
from torch.testing._internal.framework_utils import calculate_shards
from torch.testing._internal.common_device_type import \
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
@ -817,14 +817,6 @@ def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
class TestAssertClose(TestCase):
def test_sparse_support(self):
actual = torch.empty(())
expected = torch.sparse_coo_tensor(size=())
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaises(UsageError):
fn()
def test_quantized_support(self):
val = 1
actual = torch.tensor([val], dtype=torch.int32)
@ -859,6 +851,25 @@ class TestAssertClose(TestCase):
with self.assertRaisesRegex(AssertionError, "shape"):
fn()
@unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
def test_unknown_layout(self):
actual = torch.empty((2, 2))
expected = actual.to_mkldnn()
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaises(UsageError):
fn()
def test_mismatching_layout(self):
strided = torch.empty((2, 2))
sparse_coo = strided.to_sparse()
sparse_csr = strided.to_sparse_csr()
for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "layout"):
fn()
def test_mismatching_dtype(self):
actual = torch.empty((), dtype=torch.float)
expected = actual.clone().to(torch.int)
@ -1158,5 +1169,180 @@ class TestAssertCloseComplex(TestCase):
fn()
class TestAssertCloseSparseCOO(TestCase):
def test_matching_coalesced(self):
indices = (
(0, 1),
(1, 0),
)
values = (1, 2)
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce()
expected = actual.clone()
for fn in assert_close_with_inputs(actual, expected):
fn()
def test_matching_uncoalesced(self):
indices = (
(0, 1),
(1, 0),
)
values = (1, 2)
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
expected = actual.clone()
for fn in assert_close_with_inputs(actual, expected):
fn()
def test_mismatching_is_coalesced(self):
indices = (
(0, 1),
(1, 0),
)
values = (1, 2)
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
expected = actual.clone().coalesce()
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "is_coalesced"):
fn()
def test_mismatching_is_coalesced_no_check(self):
actual_indices = (
(0, 1),
(1, 0),
)
actual_values = (1, 2)
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce()
expected_indices = (
(0, 1, 1,),
(1, 0, 0,),
)
expected_values = (1, 1, 1)
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
fn(check_is_coalesced=False)
def test_mismatching_nnz(self):
actual_indices = (
(0, 1),
(1, 0),
)
actual_values = (1, 2)
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
expected_indices = (
(0, 1, 1,),
(1, 0, 0,),
)
expected_values = (1, 1, 1)
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("number of specified values")):
fn()
def test_mismatching_indices_msg(self):
actual_indices = (
(0, 1),
(1, 0),
)
actual_values = (1, 2)
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
expected_indices = (
(0, 1),
(1, 1),
)
expected_values = (1, 2)
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the indices")):
fn()
def test_mismatching_values_msg(self):
actual_indices = (
(0, 1),
(1, 0),
)
actual_values = (1, 2)
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
expected_indices = (
(0, 1),
(1, 0),
)
expected_values = (1, 3)
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
fn()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
class TestAssertCloseSparseCSR(TestCase):
def test_matching(self):
crow_indices = (0, 1, 2)
col_indices = (1, 0)
values = (1, 2)
actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
# TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
expected = torch.sparse_csr_tensor(
actual.crow_indices(), actual.col_indices(), actual.values(), size=actual.size(), device=actual.device
)
for fn in assert_close_with_inputs(actual, expected):
fn()
def test_mismatching_crow_indices_msg(self):
actual_crow_indices = (0, 1, 2)
actual_col_indices = (1, 0)
actual_values = (1, 2)
actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
expected_crow_indices = (0, 2, 2)
expected_col_indices = actual_col_indices
expected_values = actual_values
expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the crow_indices")):
fn()
def test_mismatching_col_indices_msg(self):
actual_crow_indices = (0, 1, 2)
actual_col_indices = (1, 0)
actual_values = (1, 2)
actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
expected_crow_indices = actual_crow_indices
expected_col_indices = (1, 1)
expected_values = actual_values
expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the col_indices")):
fn()
def test_mismatching_values_msg(self):
actual_crow_indices = (0, 1, 2)
actual_col_indices = (1, 0)
actual_values = (1, 2)
actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
expected_crow_indices = actual_crow_indices
expected_col_indices = actual_col_indices
expected_values = (1, 3)
expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
fn()
if __name__ == '__main__':
run_tests()