Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Simplify ATen sparse semi-structured operators based on CUTLASS (#123473)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123473
Approved by: https://github.com/cpuhrsch

Committed by: PyTorch MergeBot
Parent: 635c238bad
Commit: f5331aade5
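For context, here is a minimal, illustrative sketch (not part of the commit) of the call pattern exercised by the new test_sparse_semi_structured_ops_cutlass test in the diff below: compress a 2:4-sparse matrix with to_sparse_semi_structured, then call the torch._sparse_semi_structured_mm / torch._sparse_semi_structured_addmm operators on the packed values and metadata. The device, shapes, dtype, and mask construction here are assumptions chosen for illustration and require a CUDA GPU with sparse tensor core support.

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

SparseSemiStructuredTensor._FORCE_CUTLASS = True  # use the CUTLASS backend, as in the tests below

# Build a 2:4-sparse fp16 matrix: keep two of every four elements along each row.
dense = torch.randn(64, 128, dtype=torch.float16, device="cuda")
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").repeat(64, 32)
mat1 = dense * mask

# Compress, then hand the packed values and metadata to the new ATen operators.
compressed = to_sparse_semi_structured(mat1)
mat1_sparse, mat1_meta = compressed.values(), compressed.indices()

mat2 = torch.randn(32, 128, dtype=torch.float16, device="cuda").t()  # column-major, as in run_test below
bias = torch.randn(64, dtype=torch.float16, device="cuda")  # per-row addend of shape (m,)

out_mm = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=torch.float16)
out_addmm = torch._sparse_semi_structured_addmm(
    bias, mat1_sparse, mat1_meta, mat2, alpha=1.0, beta=1.0, out_dtype=torch.float16
)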
@@ -157,7 +157,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
         """
         Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
         We expect:
-        (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
+        (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()`
         (2) Inductor should fuse the .contiguous() call into the relu
         """
@@ -207,7 +207,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
     def test_mlp_contiguous_relu_compile_cutlass(self):
         """
-        test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
+        test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
         """
         for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
             SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
@@ -258,7 +258,7 @@ class TestSparseSemiStructured(TestCase):
         if dtype is torch.int8:
             # This should fail
             if backend == "cutlass":
-                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
+                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                     sparse_result = torch.mm(A_sparse, B)
             else:
                 with self.assertRaisesRegex(RuntimeError,
@@ -291,7 +291,7 @@ class TestSparseSemiStructured(TestCase):
             # padding with int8 throws an error because transposing B yields a contiguous output
             # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
             if backend == "cutlass":
-                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
+                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                     sparse_result = torch.mm(A_sparse, B.t())
             else:
                 with self.assertRaisesRegex(RuntimeError,
@@ -575,6 +575,73 @@ class TestSparseSemiStructured(TestCase):
             torch.backends.cuda.matmul.allow_tf32 = orig


+    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
+    @parametrize("backend", ["cutlass"])
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
+        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
+        if backend == "cutlass" and IS_WINDOWS:
+            self.skipTest("CUTLASS not supported on Windows")
+
+        def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
+            mat1 = rand_sparse_semi_structured(m, k, dtype, device)
+            # mat2 transposed as int8 case supports only row-major/column-major combination
+            mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
+            input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None
+
+            if use_input:
+                if dtype.is_floating_point:
+                    alpha = 1.3
+                    beta = -0.7
+                else:
+                    alpha = 2
+                    beta = -3
+
+            dtype_dense = torch.float32
+            mat1_dense = mat1.to(dtype_dense)
+            mat2_dense = mat2.to(dtype_dense)
+            if not use_input:
+                output0 = torch.mm(mat1_dense, mat2_dense)
+            else:
+                input_dense = input.to(dtype_dense)[:, None]
+                output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)
+
+            compressed = to_sparse_semi_structured(mat1)
+
+            mat1_sparse = compressed.values()
+            mat1_meta = compressed.indices()
+
+            if not use_input:
+                output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
+            else:
+                output1 = torch._sparse_semi_structured_addmm(
+                    input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
+                )
+            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
+
+        if dtype == torch.float32:
+            # Inputs are converted to TF32 internally for sparse GEMM,
+            # so make dense GEMM to do the same for matching results.
+            orig = torch.backends.cuda.matmul.allow_tf32
+            torch.backends.cuda.matmul.allow_tf32 = True
+
+        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
+        rtol, atol = 1e-3, 1e-3
+        if dtype == torch.bfloat16:
+            rtol, atol = 5e-3, 5e-3
+        elif dtype == torch.float32:
+            rtol, atol = 1e-3, 75e-2
+        for m, n, k, use_input in \
+                itertools.product(range(3), range(3), range(3), (False, True)):
+            m = 2 ** m * 32
+            n = 2 ** n * 32
+            k = 2 ** k * 128
+            run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)
+
+        if dtype == torch.float32:
+            torch.backends.cuda.matmul.allow_tf32 = orig
+
+
     @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
     @parametrize("backend", ["cutlass"])
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)