From ce79b09415c0513800d647ed516fe38d38140bfb Mon Sep 17 00:00:00 2001
From: Eddie Yan
Date: Thu, 13 Jun 2024 06:58:04 +0000
Subject: [PATCH] [CUDA][Sparse] Change comparison function of
 `test_sparse_semi_structured.py` and bump tolerances for `sp24_matmuls`
 (#128553)

Minor tweak of comparison as using `assert` on `torch.allclose` prevents the
mismatches from being logged.

Also bump a few tolerances that seem to be causing failures on sm86/sm90

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128553
Approved by: https://github.com/jcaip
---
 test/test_sparse_semi_structured.py | 66 ++++++++++++++---------------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py
index 625f067da467..70aefb2349ce 100644
--- a/test/test_sparse_semi_structured.py
+++ b/test/test_sparse_semi_structured.py
@@ -215,7 +215,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
         sparse_compile_result = model(input)
 
         # test that sparse_compile_result and dense_result are numerically close
-        assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
         # assert sparse and sparse_compile have the same strides,
         # as meta registrations may return contiguous tensors when the output is transposed
         # https://github.com/pytorch/pytorch/pull/114477
@@ -304,7 +304,7 @@ class TestSparseSemiStructured(TestCase):
         else:
             dense_result = torch.mm(A, B)
             sparse_result = torch.mm(A_sparse, B)
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize_backends
@@ -335,12 +335,12 @@ class TestSparseSemiStructured(TestCase):
             # test transpose
             dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A_sparse, B.t())
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
         else:
             # test transpose
             dense_result = torch.mm(A, B.t())
             sparse_result = torch.mm(A_sparse, B.t())
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@@ -386,7 +386,7 @@ class TestSparseSemiStructured(TestCase):
 
         dense_result = torch.mm(A, B.t())
         sparse_result = torch.mm(A, B_sparse.t())
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@@ -436,7 +436,7 @@ class TestSparseSemiStructured(TestCase):
         else:
             sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
     @parametrize_backends
@@ -467,7 +467,7 @@ class TestSparseSemiStructured(TestCase):
 
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @parametrize_backends
     def test_values(self, backend):
@@ -504,7 +504,7 @@ class TestSparseSemiStructured(TestCase):
         else:
             dense_res = torch.mm(A, B)
             sparse_res = torch.mm(A_sparse, B)
-        assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize_backends
@@ -577,7 +577,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         # CUTLASS
         reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned,
                                                                                        algorithm="largest_abs_values_greedy")
-        assert torch.allclose(pruned, reference_cutlass.to_dense())
+        torch.testing.assert_close(pruned, reference_cutlass.to_dense())
 
         packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
         packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
@@ -592,12 +592,12 @@ class TestSparseSemiStructuredTraining(TestCase):
                                                     packed_t_cutlass,
                                                     meta_t_cutlass,
                                                     compressed_swizzled_bitmask)
-        assert torch.allclose(reference_cutlass.to_dense(), cutlass.to_dense())
+        torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())
 
         # CUSPARSELT
         reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
                                                                                              algorithm="largest_abs_values_greedy")
-        assert torch.allclose(pruned, reference_cusparselt.to_dense())
+        torch.testing.assert_close(pruned, reference_cusparselt.to_dense())
 
         packed_cusparselt = torch._cslt_compress(pruned)
         packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
@@ -607,7 +607,7 @@ class TestSparseSemiStructuredTraining(TestCase):
                                                           packed_t_cusparselt,
                                                           None,
                                                           compressed_swizzled_bitmask)
-        assert torch.allclose(reference_cusparselt.to_dense(), cusparselt.to_dense())
+        torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())
 
 
 
@@ -644,7 +644,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         masked_a = a * mask
         ref_out = masked_a @ b
         sp24_out = a_sparse @ b
-        assert torch.allclose(ref_out, sp24_out, **atol_rtol_kw[dtype])
+        torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])
 
 
     @training_dtypes
@@ -675,11 +675,11 @@ class TestSparseSemiStructuredTraining(TestCase):
             meta_t=meta_t,
             compressed_swizzled_bitmask=bitmask,
         )
-        assert torch.allclose(a_sparse.meta.view(torch.short), sparse_mask.meta)
+        torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)
 
         ref_gemm = (mask_dense * a) @ b
         pack_gemm = a_sparse @ b
-        assert torch.allclose(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
+        torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
 
     @training_dtypes
     def test_pack_both_ways_id(self, dtype) -> None:
@@ -692,7 +692,7 @@ class TestSparseSemiStructuredTraining(TestCase):
             :4
         ]
         # Heuristic to ensure we pack the same values
-        assert torch.allclose(
+        torch.testing.assert_close(
             packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
         )
 
@@ -702,7 +702,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         # Test A@B
         pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
         max_diff = (ref_gemm - pack_gemm).abs().argmax()
-        assert torch.allclose(
+        torch.testing.assert_close(
             ref_gemm, pack_gemm, **atol_rtol_kw[dtype]
         ), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
 
@@ -710,7 +710,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
         pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
         max_diff = (ref_gemm - pack_gemm).abs().argmax()
-        assert torch.allclose(
+        torch.testing.assert_close(
             ref_gemm, pack_gemm, **atol_rtol_kw[dtype]
         ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
 
@@ -751,8 +751,8 @@ class TestSparseSemiStructuredTraining(TestCase):
             bitmask,
         ) = torch._sparse_semi_structured_tile(x)
         packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
-        assert torch.allclose(packed, packed2)
-        assert torch.allclose(packed_t, packed_t2)
+        torch.testing.assert_close(packed, packed2)
+        torch.testing.assert_close(packed_t, packed_t2)
 
     @training_dtypes
     def test_sp24_apply_dense(self, dtype) -> None:
@@ -787,8 +787,8 @@ class TestSparseSemiStructuredTraining(TestCase):
 
         dense = torch._sparse_semi_structured_apply_dense(x, bitmask)
 
-        assert torch.allclose(dense, expected)
-        assert torch.allclose(sparse.to_dense(), expected)
+        torch.testing.assert_close(dense, expected)
+        torch.testing.assert_close(sparse.to_dense(), expected)
 
 
     @training_dtypes
@@ -817,12 +817,12 @@ class TestSparseSemiStructuredTraining(TestCase):
             compressed_swizzled_bitmask=bitmask,
         )
 
-        assert torch.allclose(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1e-1)
-        assert torch.allclose(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1e-1)
-        assert torch.allclose(
-            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1e-1
+        torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
+        torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
+        torch.testing.assert_close(
+            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
         )
-        assert torch.allclose(
+        torch.testing.assert_close(
             a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
         )
 
@@ -833,7 +833,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
         a_s = to_sparse_semi_structured(a)
         with pytest.raises(NotImplementedError):
-            assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
+            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
 
 
     def test_sp24_matmuls_bmm(self) -> None:
@@ -843,7 +843,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
         a_s = to_sparse_semi_structured(a)
         with pytest.raises(NotImplementedError):
-            assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
+            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
 
 class TestSparseSemiStructuredCUTLASS(TestCase):
     """
@@ -1049,7 +1049,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
 
         dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
         sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @training_dtypes
     def test_cslt_sparse_mm_alpha(self, dtype, device):
@@ -1064,7 +1064,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
 
         dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
         dense_result = dense_result.to(dtype)
-        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
     def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
@@ -1080,7 +1080,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
 
         dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
         dense_result = dense_result.to(out_dtype)
-        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
     @inference_dtypes
@@ -1098,7 +1098,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
 
         dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
         dense_result = dense_result.to(dtype)
-        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     def test_cslt_sparse_mm_search(self, device, dtype):
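
The hunks above only swap the comparison idiom and relax a few `atol` values; the tensors and the rest of each test are unchanged. As a rough illustration of the motivation stated in the commit message, the standalone snippet below (not part of the patch; the tensor values and tolerances are made up to force a failure) contrasts the two patterns: a failing `assert torch.allclose(...)` raises a bare `AssertionError` with no details, while `torch.testing.assert_close(...)` raises one whose message reports the number of mismatched elements and the greatest absolute and relative differences.

    # Illustrative sketch only -- not part of the patch. Assumes any recent PyTorch
    # build (torch.testing.assert_close has been available since 1.9); no GPU needed.
    import torch

    actual = torch.zeros(4)
    expected = torch.tensor([0.0, 0.0, 0.0, 0.5])

    # Old idiom: a bare assert carries no information about the mismatch.
    try:
        assert torch.allclose(actual, expected, rtol=1e-3, atol=1e-3)
    except AssertionError as exc:
        print("allclose ->", repr(str(exc)))   # empty message

    # New idiom: the raised AssertionError describes the mismatch
    # (mismatched element count, greatest absolute/relative difference).
    try:
        torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
    except AssertionError as exc:
        print("assert_close ->", str(exc))

The `rtol`/`atol` values in the sketch simply mirror those used throughout the test file; nothing in it depends on the sparse kernels themselves.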