change sparse COO comparison strategy in assert_close (#68728)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68728

This removes the ability for `assert_close` to `.coalesce()` the tensors internally. Additionally, we now also check `.sparse_dim()`. Sparse team: please make sure that is the behavior you want for all sparse COO comparisons in the future. #67796 will temporarily keep BC by always coalescing, but in the future `TestCase.assertEqual` will no longer do that.

cc nikitaved pearu cpuhrsch IvanYashchuk

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D33542996

Pulled By: mruberry

fbshipit-source-id: a8d2322c6ee1ca424e3efb14ab21787328cf28fc
This commit is contained in:
Philip Meier
2022-01-12 06:40:45 -08:00
committed by Facebook GitHub Bot
parent 8d05174def
commit 802dd2b725
3 changed files with 36 additions and 52 deletions

View File

@ -1204,37 +1204,15 @@ class TestAssertCloseSparseCOO(TestCase):
for fn in assert_close_with_inputs(actual, expected):
fn()
def test_mismatching_is_coalesced(self):
indices = (
(0, 1),
(1, 0),
)
values = (1, 2)
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
expected = actual.clone().coalesce()
def test_mismatching_sparse_dims(self):
t = torch.randn(2, 3, 4)
actual = t.to_sparse()
expected = t.to_sparse(2)
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "is_coalesced"):
with self.assertRaisesRegex(AssertionError, re.escape("number of sparse dimensions in sparse COO tensors")):
fn()
def test_mismatching_is_coalesced_no_check(self):
actual_indices = (
(0, 1),
(1, 0),
)
actual_values = (1, 2)
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce()
expected_indices = (
(0, 1, 1,),
(1, 0, 0,),
)
expected_values = (1, 1, 1)
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
for fn in assert_close_with_inputs(actual, expected):
fn(check_is_coalesced=False)
def test_mismatching_nnz(self):
actual_indices = (
(0, 1),