Add out_dtype support for sparse semi-structured CUTLASS back-end (#116519)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116519
Approved by: https://github.com/cpuhrsch
Committed by PyTorch MergeBot
parent ba06951c66 · commit f081c45a34
aten/src/ATen/native/native_functions.yaml

@@ -3311,7 +3311,7 @@
   dispatch:
     CUDA: _cslt_sparse_mm_search
 
-- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None) -> Tensor
+- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
   dispatch:
     CUDA: _sparse_semi_structured_linear
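The schema gains an optional out_dtype that defaults to None, so existing callers keep the old behavior. A minimal sketch of the new call from Python (shapes are hypothetical; a CUDA build with the CUTLASS back-end is assumed, and for real use the weight must already follow the 2:4 sparsity pattern):

import torch
from torch.sparse import to_sparse_semi_structured

input = torch.randint(-4, 4, (32, 128), dtype=torch.int8, device="cuda")
weight = torch.randint(-4, 4, (64, 128), dtype=torch.int8, device="cuda")
# compressed.values() / compressed.indices() hold the compressed weight and
# its 2:4 metadata, as in the test changes further down.
compressed = to_sparse_semi_structured(weight)
out = torch._sparse_semi_structured_linear(
    input, compressed.values(), compressed.indices(),
    out_dtype=torch.int32,  # new: only int8 input with int32 output is accepted
)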
aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu

@@ -115,10 +115,11 @@ Tensor two_four_sgemm_cutlass(
 
     // Determine PyTorch datatype for the output matrix.
     auto tensor_d_dtype = at::kChar;
-    if constexpr (std::is_same_v<ElementOutput, int32_t>) {
+    if constexpr (std::is_same_v<ElementOutput, int8_t>) {
+        tensor_d_dtype = at::kChar;
+    } else if constexpr (std::is_same_v<ElementOutput, int32_t>) {
         tensor_d_dtype = at::kInt;
-    }
-    else if constexpr (std::is_same_v<ElementOutput, cutlass::half_t>) {
+    } else if constexpr (std::is_same_v<ElementOutput, cutlass::half_t>) {
         tensor_d_dtype = at::kHalf;
     } else if constexpr (std::is_same_v<ElementOutput, cutlass::bfloat16_t>) {
         tensor_d_dtype = at::kBFloat16;
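Previously int8 output simply fell through to the at::kChar initializer; the new branch makes the compile-time mapping explicit for every supported element type. The if-constexpr chain, restated as data (illustrative only, not code from the PR):

import torch

# CUTLASS output element type -> torch dtype of the returned tensor.
CUTLASS_OUTPUT_TO_TORCH_DTYPE = {
    "int8_t": torch.int8,                   # at::kChar (new explicit branch)
    "int32_t": torch.int32,                 # at::kInt
    "cutlass::half_t": torch.float16,       # at::kHalf
    "cutlass::bfloat16_t": torch.bfloat16,  # at::kBFloat16
}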
@@ -470,7 +471,8 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
 Tensor _sparse_semi_structured_linear(
       const Tensor& input, const Tensor& weight,
       const Tensor& meta, const c10::optional<Tensor>& bias_opt,
-      const c10::optional<c10::string_view> activation_opt) {
+      const c10::optional<c10::string_view> activation_opt,
+      const c10::optional<c10::ScalarType> out_dtype_opt) {
 #ifndef USE_ROCM
     // No need to check that all tensors are on CUDA device, as this
     // is provided by dispatch.
@@ -487,6 +489,12 @@ Tensor _sparse_semi_structured_linear(
     const auto activation =
         activation_opt.has_value() ? *activation_opt : "none";
 
+    TORCH_CHECK(!out_dtype_opt.has_value() ||
+                (tensor_a.dtype() == at::ScalarType::Char &&
+                 out_dtype_opt.value() == at::ScalarType::Int),
+                "_sparse_semi_structured_linear: Setting out_dtype is only "
+                "supported for int8 input and int32 output");
+
     // For now, only CC 8.x devices are supported.
     const auto dprops = at::cuda::getCurrentDeviceProperties();
     const auto is_sm8x = dprops->major == 8;
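The new TORCH_CHECK admits exactly one non-default combination. A terse Python restatement (an illustrative helper, not part of the PR):

import torch

def _check_out_dtype(input_dtype, out_dtype):
    # Mirrors the TORCH_CHECK above: out_dtype may only be set when the
    # inputs are int8 and the requested output is int32.
    assert out_dtype is None or (
        input_dtype == torch.int8 and out_dtype == torch.int32
    ), ("_sparse_semi_structured_linear: Setting out_dtype is only "
        "supported for int8 input and int32 output")

_check_out_dtype(torch.int8, torch.int32)  # accepted
_check_out_dtype(torch.float16, None)      # accepted: out_dtype left unset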
@@ -567,7 +575,6 @@ Tensor _sparse_semi_structured_linear(
         [&]() {
             using ElementInputA = int8_t;
             using ElementInputB = int8_t;
-            using ElementOutput = int32_t;
             using ElementAccumulator = int32_t;
             using ElementComputeEpilogue = int32_t;
             using ThreadblockShape =
@@ -581,27 +588,53 @@ Tensor _sparse_semi_structured_linear(
             const auto EnableActivationNone = true;
             const auto EnableActivationReLU = true;
             const auto EnableActivationSiLU = false;
-            output = two_four_sgemm_cutlass_dispatch_layouts_activation<
-                ElementInputA,
-                ElementInputB,
-                ElementOutput,
-                ElementAccumulator,
-                ElementComputeEpilogue,
-                ThreadblockShape,
-                WarpShape,
-                InstructionShape,
-                EnableRowMajorRowMajorLayouts,
-                EnableRowMajorColumnMajorLayouts,
-                EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts,
-                EnableActivationNone,
-                EnableActivationReLU,
-                EnableActivationSiLU>(
-                tensor_a,
-                tensor_b,
-                tensor_c,
-                meta,
-                activation);
+            if (out_dtype_opt.has_value()) {
+                using ElementOutput = int32_t;
+                output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+                    ElementInputA,
+                    ElementInputB,
+                    ElementOutput,
+                    ElementAccumulator,
+                    ElementComputeEpilogue,
+                    ThreadblockShape,
+                    WarpShape,
+                    InstructionShape,
+                    EnableRowMajorRowMajorLayouts,
+                    EnableRowMajorColumnMajorLayouts,
+                    EnableColumnMajorRowMajorLayouts,
+                    EnableColumnMajorColumnMajorLayouts,
+                    EnableActivationNone,
+                    EnableActivationReLU,
+                    EnableActivationSiLU>(
+                    tensor_a,
+                    tensor_b,
+                    tensor_c,
+                    meta,
+                    activation);
+            } else {
+                using ElementOutput = int8_t;
+                output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+                    ElementInputA,
+                    ElementInputB,
+                    ElementOutput,
+                    ElementAccumulator,
+                    ElementComputeEpilogue,
+                    ThreadblockShape,
+                    WarpShape,
+                    InstructionShape,
+                    EnableRowMajorRowMajorLayouts,
+                    EnableRowMajorColumnMajorLayouts,
+                    EnableColumnMajorRowMajorLayouts,
+                    EnableColumnMajorColumnMajorLayouts,
+                    EnableActivationNone,
+                    EnableActivationReLU,
+                    EnableActivationSiLU>(
+                    tensor_a,
+                    tensor_b,
+                    tensor_c,
+                    meta,
+                    activation);
+            }
             return;
         })
     AT_DISPATCH_CASE(
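The int8 dispatch case is the only one that branches on out_dtype: with out_dtype set, the kernel is instantiated with ElementOutput = int32_t, otherwise with int8_t, replacing the old behavior of always producing int32. The selection rule, restated (illustrative):

import torch

def cutlass_int8_output_dtype(out_dtype=None):
    # out_dtype set  -> ElementOutput = int32_t
    # out_dtype None -> ElementOutput = int8_t (the new default)
    return torch.int32 if out_dtype is not None else torch.int8

assert cutlass_int8_output_dtype() == torch.int8
assert cutlass_int8_output_dtype(torch.int32) == torch.int32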
test/forward_backward_compatibility/check_forward_backward_compatibility.py

@@ -136,6 +136,7 @@ ALLOW_LIST = [
     ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
     ("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
     ("aten::_efficient_attention_forward", datetime.date(2024, 1, 15)),
+    ("aten::_sparse_semi_structured_linear", datetime.date(2024, 1, 15)),
     ("onednn::qconv1d_pointwise", datetime.date(2023, 12, 31)),
     ("onednn::qconv2d_pointwise", datetime.date(2023, 12, 31)),
     ("onednn::qconv3d_pointwise", datetime.date(2023, 12, 31)),
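Changing the operator's schema trips the forward/backward compatibility check, so the op is added to ALLOW_LIST with an expiry date. A hedged sketch of the mechanism (names are illustrative, not the actual checker implementation in this file):

import datetime

def is_suppressed(op_name, allow_list, today):
    # An entry suppresses schema-compatibility failures for the named op
    # until its expiry date has passed.
    return any(name == op_name and expiry >= today
               for name, expiry in allow_list)

allow_list = [("aten::_sparse_semi_structured_linear", datetime.date(2024, 1, 15))]
assert is_suppressed("aten::_sparse_semi_structured_linear", allow_list,
                     today=datetime.date(2024, 1, 1))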
test/test_sparse_semi_structured.py

@@ -290,9 +290,7 @@ class TestSparseSemiStructured(TestCase):
             sparse_result = torch.mm(A_sparse, B.t())
         elif dtype is torch.int8:
             # test transpose
-            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
-            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A_sparse, B.t())
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
         else:
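The NOTE and the backend conditional are dropped because, after this change, the CUTLASS back-end produces int8 output by default just like cuSPARSELt; int32 is now opt-in via out_dtype. The reference pattern used by the test, in isolation (A, B, A_sparse, and device are the test fixtures from above):

# int matmul is unsupported on GPU, so the dense reference is computed on
# CPU and cast to the sparse kernel's default output dtype, now int8 for
# both back-ends.
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)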
@@ -335,7 +333,7 @@ class TestSparseSemiStructured(TestCase):
 
         # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
         if dtype is torch.int8:
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A, B_sparse.t())
         else:
             dense_result = torch.mm(A, B.t())
@@ -444,7 +442,7 @@ class TestSparseSemiStructured(TestCase):
         A_sparse = to_sparse_semi_structured(A)
         B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
         if dtype == torch.int8:
-            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
             # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
             B_t = B.t().contiguous()
             sparse_res = torch.mm(A_sparse, B_t.t())
@@ -509,7 +507,8 @@ class TestSparseSemiStructured(TestCase):
         weight_sparse = compressed.values()
         meta = compressed.indices()
 
-        output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
+        output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
+                                                       out_dtype=dtype_out if dtype == torch.int8 else None)
         torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
 
         if dtype == torch.float32:
torch/_meta_registrations.py

@@ -401,6 +401,7 @@ def meta_sparse_structured_linear(
     _meta: Tensor,
     bias: Optional[Tensor] = None,
     _activation_opt: Optional[str] = None,
+    out_dtype: Optional[torch.dtype] = None,
 ):
     output_sizes = list(input.shape)
     if bias is not None:
@@ -415,9 +416,13 @@ def meta_sparse_structured_linear(
     assert len(input.shape) == 2, "we can only handle the squashed input case"
     transposed_strides = (1, input.size(0))
 
+    if out_dtype is not None:
+        assert (
+            input.dtype == torch.int8 and out_dtype == torch.int32
+        ), "out_dtype is only supported for i8i8->i32 linear operator"
     output = input.new_empty(
         output_sizes,
-        dtype=input.dtype if input.dtype != torch.int8 else torch.int32,
+        dtype=input.dtype if out_dtype is None else out_dtype,
     ).as_strided(output_sizes, transposed_strides)
 
     return output
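The meta kernel's dtype rule, extracted into a standalone helper for clarity (illustrative, not part of the PR): pre-PR, int8 inputs were unconditionally promoted to int32 here; post-PR, the input dtype is kept unless out_dtype overrides it.

import torch

def expected_output_dtype(input_dtype, out_dtype=None):
    # Honor out_dtype when given (i8 -> i32 only), otherwise keep the
    # input dtype.
    if out_dtype is not None:
        assert input_dtype == torch.int8 and out_dtype == torch.int32, \
            "out_dtype is only supported for i8i8->i32 linear operator"
        return out_dtype
    return input_dtype

assert expected_output_dtype(torch.int8) == torch.int8
assert expected_output_dtype(torch.int8, torch.int32) == torch.int32
assert expected_output_dtype(torch.float16) == torch.float16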