[CUDA][Sparse] Change comparison function of test_sparse_semi_structured.py and bump tolerances for sp24_matmuls (#128553)

Minor tweak to the comparison: using a bare `assert` on `torch.allclose` prevents the mismatches from being logged when a test fails, so switch to `torch.testing.assert_close`, which reports them. Also bump a few tolerances that seem to be too tight and were causing failures on sm86/sm90.
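
For context, a minimal sketch of the logging difference between the two comparison styles (made-up tensors, not taken from the test file):

    import torch

    a = torch.zeros(4)
    b = torch.tensor([0.0, 0.0, 0.0, 1.0])

    # The old pattern raises a bare AssertionError, hiding which values diverged:
    #     assert torch.allclose(a, b, rtol=1e-3, atol=1e-3)

    # assert_close raises an AssertionError whose message includes the mismatch
    # count and the greatest absolute/relative differences, so CI logs show them.
    try:
        torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)
    except AssertionError as e:
        print(e)  # e.g. "Mismatched elements: 1 / 4 ... Greatest absolute difference: 1.0 ..."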

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128553
Approved by: https://github.com/jcaip
Author: Eddie Yan
Date: 2024-06-13 06:58:04 +00:00
Committed by: PyTorch MergeBot
Parent: 0678742924
Commit: ce79b09415

@@ -215,7 +215,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
sparse_compile_result = model(input)
# test that sparse_compile_result and dense_result are numerically close
-assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
# assert sparse and sparse_compile have the same strides,
# as meta registrations may return contiguous tensors when the output is transposed
# https://github.com/pytorch/pytorch/pull/114477
@@ -304,7 +304,7 @@ class TestSparseSemiStructured(TestCase):
else:
dense_result = torch.mm(A, B)
sparse_result = torch.mm(A_sparse, B)
-assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize_backends
@@ -335,12 +335,12 @@ class TestSparseSemiStructured(TestCase):
# test transpose
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
sparse_result = torch.mm(A_sparse, B.t())
-assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
else:
# test transpose
dense_result = torch.mm(A, B.t())
sparse_result = torch.mm(A_sparse, B.t())
-assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@@ -386,7 +386,7 @@ class TestSparseSemiStructured(TestCase):
dense_result = torch.mm(A, B.t())
sparse_result = torch.mm(A, B_sparse.t())
-assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@@ -436,7 +436,7 @@ class TestSparseSemiStructured(TestCase):
else:
sparse_result = model(input)
-assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
@parametrize_backends
@@ -467,7 +467,7 @@ class TestSparseSemiStructured(TestCase):
sparse_result = model(input)
-assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@parametrize_backends
def test_values(self, backend):
@@ -504,7 +504,7 @@ class TestSparseSemiStructured(TestCase):
else:
dense_res = torch.mm(A, B)
sparse_res = torch.mm(A_sparse, B)
-assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize_backends
@@ -577,7 +577,7 @@ class TestSparseSemiStructuredTraining(TestCase):
# CUTLASS
reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
-assert torch.allclose(pruned, reference_cutlass.to_dense())
+torch.testing.assert_close(pruned, reference_cutlass.to_dense())
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
@@ -592,12 +592,12 @@ class TestSparseSemiStructuredTraining(TestCase):
packed_t_cutlass,
meta_t_cutlass,
compressed_swizzled_bitmask)
-assert torch.allclose(reference_cutlass.to_dense(), cutlass.to_dense())
+torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())
# CUSPARSELT
reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
algorithm="largest_abs_values_greedy")
-assert torch.allclose(pruned, reference_cusparselt.to_dense())
+torch.testing.assert_close(pruned, reference_cusparselt.to_dense())
packed_cusparselt = torch._cslt_compress(pruned)
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
@@ -607,7 +607,7 @@ class TestSparseSemiStructuredTraining(TestCase):
packed_t_cusparselt,
None,
compressed_swizzled_bitmask)
-assert torch.allclose(reference_cusparselt.to_dense(), cusparselt.to_dense())
+torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())
@@ -644,7 +644,7 @@ class TestSparseSemiStructuredTraining(TestCase):
masked_a = a * mask
ref_out = masked_a @ b
sp24_out = a_sparse @ b
-assert torch.allclose(ref_out, sp24_out, **atol_rtol_kw[dtype])
+torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])
@training_dtypes
@@ -675,11 +675,11 @@ class TestSparseSemiStructuredTraining(TestCase):
meta_t=meta_t,
compressed_swizzled_bitmask=bitmask,
)
-assert torch.allclose(a_sparse.meta.view(torch.short), sparse_mask.meta)
+torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)
ref_gemm = (mask_dense * a) @ b
pack_gemm = a_sparse @ b
-assert torch.allclose(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
+torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
@training_dtypes
def test_pack_both_ways_id(self, dtype) -> None:
@@ -692,7 +692,7 @@ class TestSparseSemiStructuredTraining(TestCase):
:4
]
# Heuristic to ensure we pack the same values
-assert torch.allclose(
+torch.testing.assert_close(
packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
)
@@ -702,7 +702,7 @@ class TestSparseSemiStructuredTraining(TestCase):
# Test A@B
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
max_diff = (ref_gemm - pack_gemm).abs().argmax()
-assert torch.allclose(
+torch.testing.assert_close(
ref_gemm, pack_gemm,
**atol_rtol_kw[dtype]
), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
@@ -710,7 +710,7 @@ class TestSparseSemiStructuredTraining(TestCase):
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
max_diff = (ref_gemm - pack_gemm).abs().argmax()
-assert torch.allclose(
+torch.testing.assert_close(
ref_gemm, pack_gemm,
**atol_rtol_kw[dtype]
), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
@@ -751,8 +751,8 @@ class TestSparseSemiStructuredTraining(TestCase):
bitmask,
) = torch._sparse_semi_structured_tile(x)
packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
-assert torch.allclose(packed, packed2)
-assert torch.allclose(packed_t, packed_t2)
+torch.testing.assert_close(packed, packed2)
+torch.testing.assert_close(packed_t, packed_t2)
@training_dtypes
def test_sp24_apply_dense(self, dtype) -> None:
@@ -787,8 +787,8 @@ class TestSparseSemiStructuredTraining(TestCase):
dense = torch._sparse_semi_structured_apply_dense(x, bitmask)
-assert torch.allclose(dense, expected)
-assert torch.allclose(sparse.to_dense(), expected)
+torch.testing.assert_close(dense, expected)
+torch.testing.assert_close(sparse.to_dense(), expected)
@training_dtypes
@@ -817,12 +817,12 @@ class TestSparseSemiStructuredTraining(TestCase):
compressed_swizzled_bitmask=bitmask,
)
-assert torch.allclose(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1e-1)
-assert torch.allclose(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1e-1)
-assert torch.allclose(
-a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1e-1
+torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
+torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
+torch.testing.assert_close(
+a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
)
-assert torch.allclose(
+torch.testing.assert_close(
a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
)
@@ -833,7 +833,7 @@ class TestSparseSemiStructuredTraining(TestCase):
a_s = to_sparse_semi_structured(a)
with pytest.raises(NotImplementedError):
-assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
+torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
def test_sp24_matmuls_bmm(self) -> None:
@@ -843,7 +843,7 @@ class TestSparseSemiStructuredTraining(TestCase):
a_s = to_sparse_semi_structured(a)
with pytest.raises(NotImplementedError):
-assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
+torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
class TestSparseSemiStructuredCUTLASS(TestCase):
"""
@@ -1049,7 +1049,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
-assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@training_dtypes
def test_cslt_sparse_mm_alpha(self, dtype, device):
@@ -1064,7 +1064,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
dense_result = dense_result.to(dtype)
-assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
@parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
@@ -1080,7 +1080,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
dense_result = dense_result.to(out_dtype)
-assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
@parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
@inference_dtypes
@@ -1098,7 +1098,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
dense_result = dense_result.to(dtype)
-assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
def test_cslt_sparse_mm_search(self, device, dtype):
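
The `**atol_rtol_kw[dtype]` pattern seen in several hunks above unpacks a per-dtype tolerance mapping into the comparison call. A hypothetical sketch of that pattern (the values below are illustrative, not the ones defined in test_sparse_semi_structured.py):

    import torch

    # Hypothetical per-dtype tolerances; the real test file defines its own values.
    atol_rtol_kw = {
        torch.float16: {"rtol": 1e-3, "atol": 1e-3},
        torch.bfloat16: {"rtol": 1e-2, "atol": 1e-2},
    }

    def check(ref, out):
        # Expands to rtol=..., atol=... chosen by the reference tensor's dtype.
        torch.testing.assert_close(ref, out, **atol_rtol_kw[ref.dtype])

    ref = torch.randn(8, dtype=torch.float16)
    check(ref, ref.clone())  # passes: identical tensors are within any tolerance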