[CUDA][Sparse] Change comparison function of test_sparse_semi_structured.py and bump tolerances for sp24_matmuls (#128553)
Minor tweak to the comparison: a bare `assert` on `torch.allclose` reduces any failure to a detail-free `AssertionError`, so the actual mismatches never get logged. `torch.testing.assert_close` instead raises with the mismatch count and the greatest absolute/relative differences. Also bumps a few tolerances that appear to be causing failures on sm86/sm90.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128553
Approved by: https://github.com/jcaip
committed by PyTorch MergeBot
parent 0678742924
commit ce79b09415
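As an illustration of the motivation (a standalone sketch, not code from this commit): a bare `assert` discards the comparison details, while `torch.testing.assert_close` raises with a diagnostic message.

```python
import torch

a = torch.zeros(8)
b = torch.full((8,), 0.01)

# A bare assert only surfaces "AssertionError" with an empty message.
try:
    assert torch.allclose(a, b, rtol=1e-3, atol=1e-3)
except AssertionError as e:
    print("allclose:", repr(e))  # allclose: AssertionError()

# assert_close's AssertionError message includes the mismatch count and the
# greatest absolute/relative difference, which is what ends up in test logs.
try:
    torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)
except AssertionError as e:
    print("assert_close:", e)
```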
@@ -215,7 +215,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
         sparse_compile_result = model(input)
 
         # test that sparse_compile_result and dense_result are numerically close
-        assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
         # assert sparse and sparse_compile have the same strides,
         # as meta registrations may return contiguous tensors when the output is transposed
         # https://github.com/pytorch/pytorch/pull/114477
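A side note on the stride comment in the hunk above: `torch.testing.assert_close` does not compare strides by default (`check_stride=False`), so the separate stride assertion remains necessary. A minimal standalone sketch:

```python
import torch

t = torch.zeros(2, 3)
tt = t.t().contiguous().t()  # same values and shape, different strides

torch.testing.assert_close(t, tt)  # passes: strides are ignored by default
try:
    torch.testing.assert_close(t, tt, check_stride=True)
except AssertionError as e:
    print(e)  # reports the stride mismatch
```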
@@ -304,7 +304,7 @@ class TestSparseSemiStructured(TestCase):
         else:
             dense_result = torch.mm(A, B)
             sparse_result = torch.mm(A_sparse, B)
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize_backends
@@ -335,12 +335,12 @@ class TestSparseSemiStructured(TestCase):
             # test transpose
             dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A_sparse, B.t())
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
         else:
             # test transpose
             dense_result = torch.mm(A, B.t())
             sparse_result = torch.mm(A_sparse, B.t())
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@@ -386,7 +386,7 @@ class TestSparseSemiStructured(TestCase):
         dense_result = torch.mm(A, B.t())
         sparse_result = torch.mm(A, B_sparse.t())
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@@ -436,7 +436,7 @@ class TestSparseSemiStructured(TestCase):
         else:
             sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
     @parametrize_backends
@@ -467,7 +467,7 @@ class TestSparseSemiStructured(TestCase):
 
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @parametrize_backends
     def test_values(self, backend):
@@ -504,7 +504,7 @@ class TestSparseSemiStructured(TestCase):
         else:
             dense_res = torch.mm(A, B)
             sparse_res = torch.mm(A_sparse, B)
-            assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
+            torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     @parametrize_backends
@@ -577,7 +577,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
         # CUTLASS
         reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
-        assert torch.allclose(pruned, reference_cutlass.to_dense())
+        torch.testing.assert_close(pruned, reference_cutlass.to_dense())
 
         packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
         packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
@@ -592,12 +592,12 @@ class TestSparseSemiStructuredTraining(TestCase):
             packed_t_cutlass,
             meta_t_cutlass,
             compressed_swizzled_bitmask)
-        assert torch.allclose(reference_cutlass.to_dense(), cutlass.to_dense())
+        torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())
 
         # CUSPARSELT
         reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
                                                                                             algorithm="largest_abs_values_greedy")
-        assert torch.allclose(pruned, reference_cusparselt.to_dense())
+        torch.testing.assert_close(pruned, reference_cusparselt.to_dense())
 
         packed_cusparselt = torch._cslt_compress(pruned)
         packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
@@ -607,7 +607,7 @@ class TestSparseSemiStructuredTraining(TestCase):
             packed_t_cusparselt,
             None,
             compressed_swizzled_bitmask)
-        assert torch.allclose(reference_cusparselt.to_dense(), cusparselt.to_dense())
+        torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())
 
 
 
@@ -644,7 +644,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         masked_a = a * mask
         ref_out = masked_a @ b
         sp24_out = a_sparse @ b
-        assert torch.allclose(ref_out, sp24_out, **atol_rtol_kw[dtype])
+        torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])
 
 
     @training_dtypes
@@ -675,11 +675,11 @@ class TestSparseSemiStructuredTraining(TestCase):
             meta_t=meta_t,
             compressed_swizzled_bitmask=bitmask,
         )
-        assert torch.allclose(a_sparse.meta.view(torch.short), sparse_mask.meta)
+        torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)
 
         ref_gemm = (mask_dense * a) @ b
         pack_gemm = a_sparse @ b
-        assert torch.allclose(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
+        torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
 
     @training_dtypes
     def test_pack_both_ways_id(self, dtype) -> None:
@@ -692,7 +692,7 @@ class TestSparseSemiStructuredTraining(TestCase):
             :4
         ]
         # Heuristic to ensure we pack the same values
-        assert torch.allclose(
+        torch.testing.assert_close(
            packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
        )
 
@@ -702,7 +702,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         # Test A@B
         pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
         max_diff = (ref_gemm - pack_gemm).abs().argmax()
-        assert torch.allclose(
+        torch.testing.assert_close(
             ref_gemm, pack_gemm,
             **atol_rtol_kw[dtype]
         ), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
@@ -710,7 +710,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
         max_diff = (ref_gemm - pack_gemm).abs().argmax()
 
-        assert torch.allclose(
+        torch.testing.assert_close(
             ref_gemm, pack_gemm,
             **atol_rtol_kw[dtype]
         ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
@@ -751,8 +751,8 @@ class TestSparseSemiStructuredTraining(TestCase):
             bitmask,
         ) = torch._sparse_semi_structured_tile(x)
         packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
-        assert torch.allclose(packed, packed2)
-        assert torch.allclose(packed_t, packed_t2)
+        torch.testing.assert_close(packed, packed2)
+        torch.testing.assert_close(packed_t, packed_t2)
 
     @training_dtypes
     def test_sp24_apply_dense(self, dtype) -> None:
@@ -787,8 +787,8 @@ class TestSparseSemiStructuredTraining(TestCase):
 
         dense = torch._sparse_semi_structured_apply_dense(x, bitmask)
 
-        assert torch.allclose(dense, expected)
-        assert torch.allclose(sparse.to_dense(), expected)
+        torch.testing.assert_close(dense, expected)
+        torch.testing.assert_close(sparse.to_dense(), expected)
 
 
     @training_dtypes
@@ -817,12 +817,12 @@ class TestSparseSemiStructuredTraining(TestCase):
             compressed_swizzled_bitmask=bitmask,
         )
 
-        assert torch.allclose(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1e-1)
-        assert torch.allclose(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1e-1)
-        assert torch.allclose(
-            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1e-1
+        torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
+        torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
+        torch.testing.assert_close(
+            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
         )
-        assert torch.allclose(
+        torch.testing.assert_close(
             a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
         )
 
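The hunk above is where the tolerances are bumped (`atol` 1e-1 → 1.5e-1 on three of the four checks). The other training tests draw their tolerances from the file's `atol_rtol_kw` mapping instead; a hypothetical sketch of that pattern (the dtype keys and values here are illustrative, not the file's exact entries):

```python
import torch

# Hypothetical per-dtype tolerance table, consumed as **atol_rtol_kw[dtype];
# values are illustrative, not the ones in test_sparse_semi_structured.py.
atol_rtol_kw = {
    torch.float16: {"rtol": 1e-3, "atol": 1e-3},
    torch.bfloat16: {"rtol": 1e-1, "atol": 1e-1},
}

ref = torch.randn(4, 4, dtype=torch.float16)
out = ref.clone()
torch.testing.assert_close(ref, out, **atol_rtol_kw[torch.float16])
```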
@@ -833,7 +833,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         a_s = to_sparse_semi_structured(a)
 
         with pytest.raises(NotImplementedError):
-            assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
+            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
 
 
     def test_sp24_matmuls_bmm(self) -> None:
@@ -843,7 +843,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         a_s = to_sparse_semi_structured(a)
 
         with pytest.raises(NotImplementedError):
-            assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
+            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
 
 class TestSparseSemiStructuredCUTLASS(TestCase):
     """
@@ -1049,7 +1049,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
 
         dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
         sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @training_dtypes
     def test_cslt_sparse_mm_alpha(self, dtype, device):
@@ -1064,7 +1064,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
         dense_result = dense_result.to(dtype)
 
-        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
     def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
@@ -1080,7 +1080,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
         dense_result = dense_result.to(out_dtype)
 
-        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
     @inference_dtypes
@@ -1098,7 +1098,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
         dense_result = dense_result.to(dtype)
 
-        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @inference_dtypes
     def test_cslt_sparse_mm_search(self, device, dtype):