Simplify ATen sparse semi-structured operators based on CUTLASS (#123473)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123473
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Aleksandar Samardžić
2024-04-13 18:35:02 +00:00
committed by PyTorch MergeBot
parent 635c238bad
commit f5331aade5
9 changed files with 1139 additions and 185 deletions

View File

@ -157,7 +157,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
"""
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
We expect:
(1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
(1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()`
(2) Inductor should fuse the .contiguous() call into the relu
"""
@ -207,7 +207,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
def test_mlp_contiguous_relu_compile_cutlass(self):
"""
test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
"""
for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
@ -258,7 +258,7 @@ class TestSparseSemiStructured(TestCase):
if dtype is torch.int8:
# This should fail
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B)
else:
with self.assertRaisesRegex(RuntimeError,
@ -291,7 +291,7 @@ class TestSparseSemiStructured(TestCase):
# padding with int8 throws an error because transposing B yields a contiguous output
# and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B.t())
else:
with self.assertRaisesRegex(RuntimeError,
@ -575,6 +575,73 @@ class TestSparseSemiStructured(TestCase):
torch.backends.cuda.matmul.allow_tf32 = orig
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
@parametrize("backend", ["cutlass"])
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
mat1 = rand_sparse_semi_structured(m, k, dtype, device)
# mat2 transposed as int8 case supports only row-major/column-major combination
mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None
if use_input:
if dtype.is_floating_point:
alpha = 1.3
beta = -0.7
else:
alpha = 2
beta = -3
dtype_dense = torch.float32
mat1_dense = mat1.to(dtype_dense)
mat2_dense = mat2.to(dtype_dense)
if not use_input:
output0 = torch.mm(mat1_dense, mat2_dense)
else:
input_dense = input.to(dtype_dense)[:, None]
output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)
compressed = to_sparse_semi_structured(mat1)
mat1_sparse = compressed.values()
mat1_meta = compressed.indices()
if not use_input:
output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
else:
output1 = torch._sparse_semi_structured_addmm(
input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
)
torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
if dtype == torch.float32:
# Inputs are converted to TF32 internally for sparse GEMM,
# so make dense GEMM to do the same for matching results.
orig = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = True
dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
rtol, atol = 1e-3, 1e-3
if dtype == torch.bfloat16:
rtol, atol = 5e-3, 5e-3
elif dtype == torch.float32:
rtol, atol = 1e-3, 75e-2
for m, n, k, use_input in \
itertools.product(range(3), range(3), range(3), (False, True)):
m = 2 ** m * 32
n = 2 ** n * 32
k = 2 ** k * 128
run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)
if dtype == torch.float32:
torch.backends.cuda.matmul.allow_tf32 = orig
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
@parametrize("backend", ["cutlass"])
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)