Add out_dtype support for sparse semi-structured CUTLASS back-end (#116519)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116519
Approved by: https://github.com/cpuhrsch
Aleksandar Samardžić
2023-12-28 21:59:59 +01:00
committed by PyTorch MergeBot
parent ba06951c66
commit f081c45a34
5 changed files with 72 additions and 34 deletions

aten/src/ATen/native/native_functions.yaml

@@ -3311,7 +3311,7 @@
   dispatch:
     CUDA: _cslt_sparse_mm_search

-- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None) -> Tensor
+- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
   dispatch:
     CUDA: _sparse_semi_structured_linear
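For orientation, here is a minimal sketch of how the extended schema can be exercised from Python. The shapes, the 2:4 mask construction, and the backend assumptions (CUTLASS selected, CUDA device with compute capability 8.x) are illustrative and not part of this diff:

```python
import torch
from torch.sparse import to_sparse_semi_structured

# Illustrative fixture: an int8 weight obeying the 2:4 pattern, i.e. two
# non-zeros in every group of four elements along a row. Sizes are assumptions.
mask = torch.tensor([1, 1, 0, 0], dtype=torch.int8, device="cuda").tile(128, 32)
weight = torch.randint(-8, 8, (128, 128), dtype=torch.int8, device="cuda") * mask
compressed = to_sparse_semi_structured(weight)

x = torch.randint(-8, 8, (64, 128), dtype=torch.int8, device="cuda")

# Default path after this change: int8 inputs produce an int8 output.
out_i8 = torch._sparse_semi_structured_linear(
    x, compressed.values(), compressed.indices())

# New keyword: request the int32 accumulator dtype as the output dtype.
out_i32 = torch._sparse_semi_structured_linear(
    x, compressed.values(), compressed.indices(), out_dtype=torch.int32)

assert out_i8.dtype == torch.int8 and out_i32.dtype == torch.int32
```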

aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu

@@ -115,10 +115,11 @@ Tensor two_four_sgemm_cutlass(
     // Determine PyTorch datatype for the output matrix.
     auto tensor_d_dtype = at::kChar;
-    if constexpr (std::is_same_v<ElementOutput, int32_t>) {
+    if constexpr (std::is_same_v<ElementOutput, int8_t>) {
+        tensor_d_dtype = at::kChar;
+    } else if constexpr (std::is_same_v<ElementOutput, int32_t>) {
         tensor_d_dtype = at::kInt;
-    }
-    else if constexpr (std::is_same_v<ElementOutput, cutlass::half_t>) {
+    } else if constexpr (std::is_same_v<ElementOutput, cutlass::half_t>) {
         tensor_d_dtype = at::kHalf;
     } else if constexpr (std::is_same_v<ElementOutput, cutlass::bfloat16_t>) {
         tensor_d_dtype = at::kBFloat16;
@@ -470,7 +471,8 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
 Tensor _sparse_semi_structured_linear(
     const Tensor& input, const Tensor& weight,
     const Tensor& meta, const c10::optional<Tensor>& bias_opt,
-    const c10::optional<c10::string_view> activation_opt) {
+    const c10::optional<c10::string_view> activation_opt,
+    const c10::optional<c10::ScalarType> out_dtype_opt) {
 #ifndef USE_ROCM
     // No need to check that all tensors are on CUDA device, as this
     // is provided by dispatch.
@@ -487,6 +489,12 @@ Tensor _sparse_semi_structured_linear(
     const auto activation =
         activation_opt.has_value() ? *activation_opt : "none";
+    TORCH_CHECK(!out_dtype_opt.has_value() ||
+                (tensor_a.dtype() == at::ScalarType::Char &&
+                 out_dtype_opt.value() == at::ScalarType::Int),
+                "_sparse_semi_structured_linear: Setting out_dtype is only "
+                "supported for int8 input and int32 output");
+
     // For now, only CC 8.x devices are supported.
     const auto dprops = at::cuda::getCurrentDeviceProperties();
     const auto is_sm8x = dprops->major == 8;
@@ -567,7 +575,6 @@ Tensor _sparse_semi_structured_linear(
         [&]() {
             using ElementInputA = int8_t;
             using ElementInputB = int8_t;
-            using ElementOutput = int32_t;
             using ElementAccumulator = int32_t;
             using ElementComputeEpilogue = int32_t;
             using ThreadblockShape =
@@ -581,27 +588,53 @@
             const auto EnableActivationNone = true;
             const auto EnableActivationReLU = true;
             const auto EnableActivationSiLU = false;
-            output = two_four_sgemm_cutlass_dispatch_layouts_activation<
-                ElementInputA,
-                ElementInputB,
-                ElementOutput,
-                ElementAccumulator,
-                ElementComputeEpilogue,
-                ThreadblockShape,
-                WarpShape,
-                InstructionShape,
-                EnableRowMajorRowMajorLayouts,
-                EnableRowMajorColumnMajorLayouts,
-                EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts,
-                EnableActivationNone,
-                EnableActivationReLU,
-                EnableActivationSiLU>(
-                tensor_a,
-                tensor_b,
-                tensor_c,
-                meta,
-                activation);
+            if (out_dtype_opt.has_value()) {
+                using ElementOutput = int32_t;
+                output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+                    ElementInputA,
+                    ElementInputB,
+                    ElementOutput,
+                    ElementAccumulator,
+                    ElementComputeEpilogue,
+                    ThreadblockShape,
+                    WarpShape,
+                    InstructionShape,
+                    EnableRowMajorRowMajorLayouts,
+                    EnableRowMajorColumnMajorLayouts,
+                    EnableColumnMajorRowMajorLayouts,
+                    EnableColumnMajorColumnMajorLayouts,
+                    EnableActivationNone,
+                    EnableActivationReLU,
+                    EnableActivationSiLU>(
+                    tensor_a,
+                    tensor_b,
+                    tensor_c,
+                    meta,
+                    activation);
+            } else {
+                using ElementOutput = int8_t;
+                output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+                    ElementInputA,
+                    ElementInputB,
+                    ElementOutput,
+                    ElementAccumulator,
+                    ElementComputeEpilogue,
+                    ThreadblockShape,
+                    WarpShape,
+                    InstructionShape,
+                    EnableRowMajorRowMajorLayouts,
+                    EnableRowMajorColumnMajorLayouts,
+                    EnableColumnMajorRowMajorLayouts,
+                    EnableColumnMajorColumnMajorLayouts,
+                    EnableActivationNone,
+                    EnableActivationReLU,
+                    EnableActivationSiLU>(
+                    tensor_a,
+                    tensor_b,
+                    tensor_c,
+                    meta,
+                    activation);
+            }
             return;
         })
         AT_DISPATCH_CASE(
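The TORCH_CHECK added above sits before any shape or metadata validation, so the new contract can be observed without valid sparse fixtures. A hedged sketch (the placeholder tensors and the assumption that the dtype check fires first are mine, not part of the diff):

```python
import torch

# Placeholder tensors on CUDA; shapes and metadata are deliberately meaningless,
# since the out_dtype check is expected to reject the call before inspecting them.
x = torch.zeros(32, 64, dtype=torch.float16, device="cuda")
w = torch.zeros(32, 32, dtype=torch.float16, device="cuda")
m = torch.zeros(32, 4, dtype=torch.int16, device="cuda")

try:
    torch._sparse_semi_structured_linear(x, w, m, out_dtype=torch.int32)
except RuntimeError as e:
    # Error text comes from the TORCH_CHECK introduced in this file.
    assert "Setting out_dtype is only" in str(e)
```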

test/forward_backward_compatibility/check_forward_backward_compatibility.py

@@ -136,6 +136,7 @@ ALLOW_LIST = [
     ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
     ("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
     ("aten::_efficient_attention_forward", datetime.date(2024, 1, 15)),
+    ("aten::_sparse_semi_structured_linear", datetime.date(2024, 1, 15)),
     ("onednn::qconv1d_pointwise", datetime.date(2023, 12, 31)),
     ("onednn::qconv2d_pointwise", datetime.date(2023, 12, 31)),
     ("onednn::qconv3d_pointwise", datetime.date(2023, 12, 31)),

test/test_sparse_semi_structured.py

@@ -290,9 +290,7 @@ class TestSparseSemiStructured(TestCase):
             sparse_result = torch.mm(A_sparse, B.t())
         elif dtype is torch.int8:
             # test transpose
-            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
-            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A_sparse, B.t())
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
         else:
@@ -335,7 +333,7 @@
         # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
         if dtype is torch.int8:
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A, B_sparse.t())
         else:
             dense_result = torch.mm(A, B.t())
@@ -444,7 +442,7 @@
         A_sparse = to_sparse_semi_structured(A)
         B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
         if dtype == torch.int8:
-            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
             # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
             B_t = B.t().contiguous()
             sparse_res = torch.mm(A_sparse, B_t.t())
@@ -509,7 +507,8 @@
         weight_sparse = compressed.values()
         meta = compressed.indices()
-        output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
+        output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
+                                                       out_dtype=dtype_out if dtype == torch.int8 else None)
         torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
         if dtype == torch.float32:

torch/_meta_registrations.py

@@ -401,6 +401,7 @@ def meta_sparse_structured_linear(
     _meta: Tensor,
     bias: Optional[Tensor] = None,
     _activation_opt: Optional[str] = None,
+    out_dtype: Optional[torch.dtype] = None,
 ):
     output_sizes = list(input.shape)
     if bias is not None:
@@ -415,9 +416,13 @@
         assert len(input.shape) == 2, "we can only handle the squashed input case"
         transposed_strides = (1, input.size(0))

+    if out_dtype is not None:
+        assert (
+            input.dtype == torch.int8 and out_dtype == torch.int32
+        ), "out_dtype is only supported for i8i8->i32 linear operator"
+
     output = input.new_empty(
         output_sizes,
-        dtype=input.dtype if input.dtype != torch.int8 else torch.int32,
+        dtype=input.dtype if out_dtype is None else out_dtype,
     ).as_strided(output_sizes, transposed_strides)
     return output
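The meta kernel can be smoke-tested on the meta device, which exercises only shape and dtype propagation. The shapes and the int16 metadata stand-in below are assumptions for illustration:

```python
import torch

# Meta-device stand-ins: tensor contents and the real metadata layout do not
# matter here; only dtypes and shapes flow through the registration above.
x = torch.empty(64, 128, dtype=torch.int8, device="meta")
w = torch.empty(128, 64, dtype=torch.int8, device="meta")   # compressed 2:4 values
m = torch.empty(128, 8, dtype=torch.int16, device="meta")   # metadata stand-in

out = torch._sparse_semi_structured_linear(x, w, m, out_dtype=torch.int32)
assert out.dtype == torch.int32          # propagated from out_dtype

out_default = torch._sparse_semi_structured_linear(x, w, m)
assert out_default.dtype == torch.int8   # falls back to the input dtype
```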