mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add support for sparse tensors in torch.testing.assert_close
(#58844)
Summary:
This adds support for sparse tensors the same way `torch.testing._internal.common_utils.TestCase.assertEqual` does:
5c7dace309/torch/testing/_internal/common_utils.py (L1287-L1313)
- Tensors are coalesced before comparison.
- Indices and values are compared individually.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58844
Reviewed By: zou3519
Differential Revision: D29160250
Pulled By: mruberry
fbshipit-source-id: b0955656c2c7ff3db37a1367427ca54ca14f2e87
This commit is contained in:
committed by
Facebook GitHub Bot
parent
80f40b172f
commit
6ea22672c4
@ -11,7 +11,7 @@ from typing import Any, Callable, Iterator, List, Tuple
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
|
||||
(IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
|
||||
from torch.testing._internal.framework_utils import calculate_shards
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
|
||||
@ -817,14 +817,6 @@ def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
|
||||
|
||||
|
||||
class TestAssertClose(TestCase):
|
||||
def test_sparse_support(self):
|
||||
actual = torch.empty(())
|
||||
expected = torch.sparse_coo_tensor(size=())
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaises(UsageError):
|
||||
fn()
|
||||
|
||||
def test_quantized_support(self):
|
||||
val = 1
|
||||
actual = torch.tensor([val], dtype=torch.int32)
|
||||
@ -859,6 +851,25 @@ class TestAssertClose(TestCase):
|
||||
with self.assertRaisesRegex(AssertionError, "shape"):
|
||||
fn()
|
||||
|
||||
@unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
|
||||
def test_unknown_layout(self):
|
||||
actual = torch.empty((2, 2))
|
||||
expected = actual.to_mkldnn()
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaises(UsageError):
|
||||
fn()
|
||||
|
||||
def test_mismatching_layout(self):
|
||||
strided = torch.empty((2, 2))
|
||||
sparse_coo = strided.to_sparse()
|
||||
sparse_csr = strided.to_sparse_csr()
|
||||
|
||||
for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, "layout"):
|
||||
fn()
|
||||
|
||||
def test_mismatching_dtype(self):
|
||||
actual = torch.empty((), dtype=torch.float)
|
||||
expected = actual.clone().to(torch.int)
|
||||
@ -1158,5 +1169,180 @@ class TestAssertCloseComplex(TestCase):
|
||||
fn()
|
||||
|
||||
|
||||
class TestAssertCloseSparseCOO(TestCase):
|
||||
def test_matching_coalesced(self):
|
||||
indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
values = (1, 2)
|
||||
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce()
|
||||
expected = actual.clone()
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
fn()
|
||||
|
||||
def test_matching_uncoalesced(self):
|
||||
indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
values = (1, 2)
|
||||
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
|
||||
expected = actual.clone()
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
fn()
|
||||
|
||||
def test_mismatching_is_coalesced(self):
|
||||
indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
values = (1, 2)
|
||||
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
|
||||
expected = actual.clone().coalesce()
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, "is_coalesced"):
|
||||
fn()
|
||||
|
||||
def test_mismatching_is_coalesced_no_check(self):
|
||||
actual_indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
actual_values = (1, 2)
|
||||
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce()
|
||||
|
||||
expected_indices = (
|
||||
(0, 1, 1,),
|
||||
(1, 0, 0,),
|
||||
)
|
||||
expected_values = (1, 1, 1)
|
||||
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
fn(check_is_coalesced=False)
|
||||
|
||||
def test_mismatching_nnz(self):
|
||||
actual_indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
actual_values = (1, 2)
|
||||
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
|
||||
|
||||
expected_indices = (
|
||||
(0, 1, 1,),
|
||||
(1, 0, 0,),
|
||||
)
|
||||
expected_values = (1, 1, 1)
|
||||
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("number of specified values")):
|
||||
fn()
|
||||
|
||||
def test_mismatching_indices_msg(self):
|
||||
actual_indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
actual_values = (1, 2)
|
||||
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
|
||||
|
||||
expected_indices = (
|
||||
(0, 1),
|
||||
(1, 1),
|
||||
)
|
||||
expected_values = (1, 2)
|
||||
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the indices")):
|
||||
fn()
|
||||
|
||||
def test_mismatching_values_msg(self):
|
||||
actual_indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
actual_values = (1, 2)
|
||||
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
|
||||
|
||||
expected_indices = (
|
||||
(0, 1),
|
||||
(1, 0),
|
||||
)
|
||||
expected_values = (1, 3)
|
||||
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
|
||||
fn()
|
||||
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
|
||||
class TestAssertCloseSparseCSR(TestCase):
|
||||
def test_matching(self):
|
||||
crow_indices = (0, 1, 2)
|
||||
col_indices = (1, 0)
|
||||
values = (1, 2)
|
||||
actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
|
||||
# TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
|
||||
expected = torch.sparse_csr_tensor(
|
||||
actual.crow_indices(), actual.col_indices(), actual.values(), size=actual.size(), device=actual.device
|
||||
)
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
fn()
|
||||
|
||||
def test_mismatching_crow_indices_msg(self):
|
||||
actual_crow_indices = (0, 1, 2)
|
||||
actual_col_indices = (1, 0)
|
||||
actual_values = (1, 2)
|
||||
actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
|
||||
|
||||
expected_crow_indices = (0, 2, 2)
|
||||
expected_col_indices = actual_col_indices
|
||||
expected_values = actual_values
|
||||
expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the crow_indices")):
|
||||
fn()
|
||||
|
||||
def test_mismatching_col_indices_msg(self):
|
||||
actual_crow_indices = (0, 1, 2)
|
||||
actual_col_indices = (1, 0)
|
||||
actual_values = (1, 2)
|
||||
actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
|
||||
|
||||
expected_crow_indices = actual_crow_indices
|
||||
expected_col_indices = (1, 1)
|
||||
expected_values = actual_values
|
||||
expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the col_indices")):
|
||||
fn()
|
||||
|
||||
def test_mismatching_values_msg(self):
|
||||
actual_crow_indices = (0, 1, 2)
|
||||
actual_col_indices = (1, 0)
|
||||
actual_values = (1, 2)
|
||||
actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
|
||||
|
||||
expected_crow_indices = actual_crow_indices
|
||||
expected_col_indices = actual_col_indices
|
||||
expected_values = (1, 3)
|
||||
expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the values")):
|
||||
fn()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user