mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] Add fast semi-structured spasification kernels (#122350)
This PR adds in fast semi-structured sparsification kernels to PyTorch. These kernels allow for accelerated semi-structured sparsification kernels in PyTorch. The kernels have been added as aten native functions In particular, three new functions have been added: * `torch._sparse_semi_structured_tile` This function will return the packed representation and metadata for both X and X', as well as the thread masks. Note that this applies 2:4 sparsity in a 4x4 tile instead of a 1x4 strip as usual. * `torch._sparse_semi_structured_apply` This function takes in an input tensor and thread masks from the above function and returns a packed representation and metadata from applying thread masks to the input tensor. * `torch._sparse_semi_structured_apply_dense` This function does the same thing as above but instead of returning the tensor in the sparse representation it returns it in the dense representation The subclasses have also been updated to add a new `prune_dense_static_sort` classmethod to create sparse tensors with this format. I've added some additional documentatino on how to calculate the compressed tensors needed to create a SparseSemiStructuredTensor oneself. To this end, there are two new helper functions added: `sparse_semi_structured_tile` `compute_compressed_swizzled_bitmask` Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
d8717c2d68
commit
c63a7b5691
@ -3342,6 +3342,18 @@
|
||||
dispatch:
|
||||
CUDA: _cslt_sparse_mm_search
|
||||
|
||||
- func: _sparse_semi_structured_tile(Tensor input, str algorithm="", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: _sparse_semi_structured_tile
|
||||
|
||||
- func: _sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: _sparse_semi_structured_apply
|
||||
|
||||
- func: _sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor
|
||||
dispatch:
|
||||
CUDA: _sparse_semi_structured_apply_dense
|
||||
|
||||
- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
|
||||
dispatch:
|
||||
CUDA: _sparse_semi_structured_linear
|
||||
|
184
aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h
Normal file
184
aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h
Normal file
@ -0,0 +1,184 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
|
||||
#include <ATen/native/sparse/cuda/StaticSort.h>
|
||||
#include <cutlass/bfloat16.h>
|
||||
#include <cutlass/half.h>
|
||||
|
||||
// Given 4x4 values, computes the selected indices that will remain after 2:4
|
||||
// sparsification, as a bitmask.
|
||||
// NOTE: Algorithms might select LESS than 8 values in total in some cases.
|
||||
|
||||
namespace platform {
|
||||
template <>
|
||||
struct numeric_limits<cutlass::bfloat16_t> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t infinity() {
|
||||
return cutlass::bfloat16_t::bitcast(0x7f80);
|
||||
}
|
||||
};
|
||||
} // namespace platform
|
||||
|
||||
namespace at::native{
|
||||
|
||||
template <typename Element, typename Pointwise>
|
||||
struct TileValueOrderedT {
|
||||
union {
|
||||
struct {
|
||||
Element value;
|
||||
uint2b_t col;
|
||||
uint2b_t row;
|
||||
} parts;
|
||||
uint32_t raw;
|
||||
};
|
||||
CUTLASS_DEVICE bool operator<(
|
||||
TileValueOrderedT<Element, Pointwise> const& other) const {
|
||||
return Pointwise::apply(parts.value) < Pointwise::apply(other.parts.value);
|
||||
}
|
||||
CUTLASS_DEVICE TileValueOrderedT() {}
|
||||
};
|
||||
|
||||
// Operations that we can apply to rank the values
|
||||
struct IdentityOp {
|
||||
template <typename T>
|
||||
static T CUTLASS_HOST_DEVICE apply(T const& x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
// Can be applied to rank based on absolute value
|
||||
struct AbsOp {
|
||||
template <typename T>
|
||||
static T CUTLASS_HOST_DEVICE apply(T const& x) {
|
||||
return cutlass::abs(x);
|
||||
}
|
||||
};
|
||||
|
||||
// Given 4x4 values, computes the selected indices that will remain after 2:4
|
||||
// sparsification, as a bitmask. We have 2 constraints:
|
||||
// (1) At most 2 values per line
|
||||
// (2) At most 2 values per column
|
||||
// This means we can select at most 8 values in total.
|
||||
// ALGO: We use a greedy algorithm, where we take values in the 4x4
|
||||
// tile in descending order. If a value fits (because the line/col is not
|
||||
// already full), we select it. Then we move on to the next one.
|
||||
// NOTE: This algorithm might select LESS than 8 values in total in some cases.
|
||||
// NOTE (2): RF are not indexable, so we shouldn't rely on indexing
|
||||
// values at any point, otherwise they will be stored in local memory.
|
||||
template <typename Op = IdentityOp>
|
||||
struct LargestValuesGreedy {
|
||||
template <typename T>
|
||||
static CUTLASS_DEVICE T outOfBoundsFillValue() {
|
||||
return -platform::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
template <typename Tile4x4Accessor>
|
||||
CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) {
|
||||
using TileValueOrdered =
|
||||
TileValueOrderedT<typename Tile4x4Accessor::Element, Op>;
|
||||
using TileValuesFragment = cutlass::Array<TileValueOrdered, 4 * 4>;
|
||||
Indices4x4 indices;
|
||||
TileValuesFragment values_ordered;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
TileValueOrdered& v = values_ordered[i * 4 + j];
|
||||
v.parts.value = values.at(i, j).get();
|
||||
v.parts.col = j;
|
||||
v.parts.row = i;
|
||||
}
|
||||
}
|
||||
// Use a sorting network (aka without branches) to avoid
|
||||
// warp divergence
|
||||
StaticSort<TileValuesFragment::kElements> sorter;
|
||||
sorter(values_ordered);
|
||||
|
||||
// bitmask to store how many we have selected on a given row/col
|
||||
// 0 selected: (numPerRow >> 2*row) = 00 (0)
|
||||
// 1 selected: (numPerRow >> 2*row) = 01 (1)
|
||||
// 2 selected: (numPerRow >> 2*row) = 11 (3)
|
||||
uint32_t numPerRow = 0;
|
||||
uint32_t numPerCol = 0;
|
||||
indices = 0;
|
||||
|
||||
// Take as many as we can, starting with the largest values
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = values_ordered.size() - 1; i >= 0; i--) {
|
||||
auto& e = values_ordered[i];
|
||||
|
||||
uint32_t rcount = uint2b_t(numPerRow >> 2 * e.parts.row);
|
||||
uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col);
|
||||
// NOTE: This is more efficient (yet equivalent) to:
|
||||
// `rcount != 3 && ccount != 3`
|
||||
bool selected = (rcount + ccount) <= 2;
|
||||
indices |= selected << (e.parts.col + 4 * e.parts.row);
|
||||
|
||||
numPerRow |= (rcount + selected) << 2 * e.parts.row;
|
||||
numPerCol |= (ccount + selected) << 2 * e.parts.col;
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
};
|
||||
|
||||
// We consider each rows independantly in order
|
||||
// This is to ensure that a row's sparsity pattern is only determined
|
||||
// by its values and the rows before (but never the rows after)
|
||||
// This enforces causality strictly
|
||||
template <typename Op = IdentityOp>
|
||||
struct Causal1122 {
|
||||
template <typename T>
|
||||
static CUTLASS_DEVICE T outOfBoundsFillValue() {
|
||||
return -platform::numeric_limits<T>::infinity();
|
||||
}
|
||||
|
||||
template <typename Tile4x4Accessor>
|
||||
CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) {
|
||||
static constexpr int kMaxValuesPerRow[] = {1, 1, 2, 2};
|
||||
using TileValueOrdered =
|
||||
TileValueOrderedT<typename Tile4x4Accessor::Element, Op>;
|
||||
using TileValuesFragment = cutlass::Array<TileValueOrdered, 4>;
|
||||
Indices4x4 indices = 0;
|
||||
|
||||
uint32_t numPerCol = 0; // <- see doc in `LargestValuesGreedy`
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < 4; ++row) {
|
||||
int row_count = 0;
|
||||
TileValuesFragment values_ordered;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int col = 0; col < 4; ++col) {
|
||||
TileValueOrdered& v = values_ordered[col];
|
||||
v.parts.value = values.at(row, col).get();
|
||||
v.parts.col = col;
|
||||
}
|
||||
// Use a sorting network (aka without branches) to avoid
|
||||
// warp divergence
|
||||
StaticSort<TileValuesFragment::kElements> sorter;
|
||||
sorter(values_ordered);
|
||||
|
||||
// Take as many as we can, starting with the largest values
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = values_ordered.size() - 1; i >= 0; i--) {
|
||||
auto& e = values_ordered[i];
|
||||
|
||||
uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col);
|
||||
bool selected = ccount != 3 && (row_count < kMaxValuesPerRow[row]);
|
||||
indices |= selected << (e.parts.col + 4 * row);
|
||||
numPerCol |= (ccount + selected) << 2 * e.parts.col;
|
||||
row_count += selected;
|
||||
}
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void named_algorithms(T callback) {
|
||||
callback(LargestValuesGreedy<IdentityOp>(), "largest_values_greedy");
|
||||
callback(Causal1122<IdentityOp>(), "causal1122");
|
||||
callback(LargestValuesGreedy<AbsOp>(), "largest_abs_values_greedy");
|
||||
// default one
|
||||
callback(LargestValuesGreedy<IdentityOp>(), "");
|
||||
}
|
||||
|
||||
} // namespace
|
@ -0,0 +1,174 @@
|
||||
#include <ATen/ScalarOps.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/autocast_mode.h>
|
||||
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
|
||||
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at::native {
|
||||
struct Params {
|
||||
uint64_t const* threads_masks;
|
||||
|
||||
uint16_t const* input;
|
||||
int64_t input_stride;
|
||||
int64_t input_dim0;
|
||||
int64_t input_dim1;
|
||||
|
||||
uint16_t* output;
|
||||
int64_t output_stride;
|
||||
|
||||
__host__ dim3 getBlocksGrid() const {
|
||||
return dim3(
|
||||
cutlass::ceil_div(input_dim0, kWarpX),
|
||||
cutlass::ceil_div(input_dim1, kWarpY),
|
||||
1);
|
||||
}
|
||||
|
||||
static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
|
||||
return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
|
||||
Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
|
||||
gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
|
||||
int64_t strideX = gridDim.y * getThreadsGrid().y;
|
||||
gmem_threads_masks +=
|
||||
(blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
|
||||
return gmem_threads_masks;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool kInputRowMajor = true, bool kOutputRowMajor = true>
|
||||
__global__ void __launch_bounds__(32 /* num_threads */, 32) sparse_semi_structured_apply_dense_k(Params p) {
|
||||
using Fragment = cutlass::Array<uint16_t, 8>;
|
||||
|
||||
// Top-left of the 8x8 tile we own
|
||||
int warp_x = blockIdx.x * kWarpX;
|
||||
int warp_y = blockIdx.y * kWarpY;
|
||||
int x = warp_x + threadIdx.x * kThreadX;
|
||||
int y = warp_y + threadIdx.y * kThreadY;
|
||||
|
||||
uint16_t* output = p.output + x * p.output_stride + y;
|
||||
Tile8x8Masks indices = *p.getCurrentThreadIndices();
|
||||
|
||||
// Load dense
|
||||
Fragment lines[8];
|
||||
if (kInputRowMajor) {
|
||||
uint16_t const* input = p.input + x * p.input_stride + y;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
|
||||
lines[i], input + i * p.input_stride, true);
|
||||
}
|
||||
} else {
|
||||
uint16_t const* input = p.input + x + y * p.input_stride;
|
||||
Fragment columns[8];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
|
||||
columns[i], input + i * p.input_stride, true);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
lines[i][j] = columns[j][i].get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
Indices4x4 masks[2];
|
||||
if (row == 0) {
|
||||
masks[0] = indices.a;
|
||||
masks[1] = indices.b;
|
||||
} else {
|
||||
masks[0] = indices.c;
|
||||
masks[1] = indices.d;
|
||||
}
|
||||
|
||||
// Apply mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < 2; ++m) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int r = 0; r < 4; ++r) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < 4; ++c) {
|
||||
lines[4 * row + r][4 * m + c] = lines[4 * row + r][4 * m + c] *
|
||||
int((masks[m] >> (4 * r + c)) & 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
static_assert(kOutputRowMajor, "Transpose here for ColMajor output");
|
||||
// Save dense with zeros
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
cutlass::arch::global_store<Fragment, sizeof(Fragment)>(
|
||||
lines[i], output + i * p.output_stride, true);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _sparse_semi_structured_apply_dense(
|
||||
const Tensor& input,
|
||||
const Tensor& threads_masks) {
|
||||
|
||||
TORCH_CHECK(
|
||||
input.scalar_type() == at::ScalarType::Half ||
|
||||
input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Unsupported `input` dtype");
|
||||
TORCH_CHECK(
|
||||
input.stride(0) == 1 || input.stride(1) == 1,
|
||||
"`input` should be either RowMajor or ColMajor. Invalid memory layout - try .contiguous()?");
|
||||
|
||||
auto roundedx = cutlass::round_up(input.size(0), kWarpX);
|
||||
auto roundedy = cutlass::round_up(input.size(1), kWarpY);
|
||||
|
||||
Params p;
|
||||
p.input = (uint16_t const*)input.data_ptr();
|
||||
p.input_dim0 = input.size(0);
|
||||
p.input_dim1 = input.size(1);
|
||||
p.threads_masks = (uint64_t const*)threads_masks.data_ptr();
|
||||
|
||||
TORCH_CHECK(threads_masks.dim() == 3);
|
||||
TORCH_CHECK(threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
|
||||
TORCH_CHECK(threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
|
||||
TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
|
||||
TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
|
||||
TORCH_CHECK(threads_masks.stride(2) == 1);
|
||||
TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);
|
||||
|
||||
at::Tensor output = at::empty({p.input_dim0, p.input_dim1}, input.options());
|
||||
TORCH_INTERNAL_ASSERT(output.stride(-1) == 1, "expected RowMajor?");
|
||||
p.output = (uint16_t*)output.data_ptr();
|
||||
|
||||
bool inputRowMajor = input.stride(-1) == 1;
|
||||
bool outputRowMajor = output.stride(-1) == 1;
|
||||
p.input_stride = input.stride(inputRowMajor ? 0 : 1);
|
||||
p.output_stride = output.stride(outputRowMajor ? 0 : 1);
|
||||
at::cuda::CUDAGuard device_guard(input.device());
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
size_t smem_bytes = 0;
|
||||
if (inputRowMajor && outputRowMajor) {
|
||||
sparse_semi_structured_apply_dense_k<true, true>
|
||||
<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
|
||||
} else if (!inputRowMajor && outputRowMajor) {
|
||||
sparse_semi_structured_apply_dense_k<false, true>
|
||||
<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Unsupported configuration: `input` is ",
|
||||
inputRowMajor ? "RowMajor" : "ColMajor",
|
||||
", and `output` is ",
|
||||
outputRowMajor ? "RowMajor" : "ColMajor");
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace
|
520
aten/src/ATen/native/sparse/cuda/SparseSemiStructuredPack.h
Normal file
520
aten/src/ATen/native/sparse/cuda/SparseSemiStructuredPack.h
Normal file
@ -0,0 +1,520 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/sparse/cuda/StaticSort.h>
|
||||
#include <cutlass/arch/memory.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/bfloat16.h>
|
||||
#include <cutlass/fast_math.h>
|
||||
#include <cutlass/half.h>
|
||||
#include <cutlass/integer_subbyte.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
using cutlass::uint1b_t;
|
||||
using cutlass::uint2b_t;
|
||||
using cutlass::uint4b_t;
|
||||
using uint8b_t = cutlass::integer_subbyte<8, false>;
|
||||
using ReorderedLayoutInputE = cutlass::layout::ColumnMajorInterleaved<2>;
|
||||
using ElementInputE = uint16_t;
|
||||
constexpr int kWarpX = 32;
|
||||
constexpr int kWarpY = 64;
|
||||
constexpr int kThreadX = 8;
|
||||
constexpr int kThreadY = 8;
|
||||
|
||||
// bitmask of selected values, in col-major storage
|
||||
// eg: indices & (1 << (col + 4 * row))
|
||||
using Indices4x4 = uint16_t;
|
||||
|
||||
struct Tile8x8Masks {
|
||||
Indices4x4 a, b, c, d;
|
||||
CUTLASS_DEVICE Tile8x8Masks() {
|
||||
a = b = c = d = 0;
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(sizeof(Tile8x8Masks) == 8, "should be exactly uint64_t");
|
||||
|
||||
// Each thread has data for an 8x8 area of the input tensor
|
||||
// Due to the very specific format of the metadata, 32 consecutive bits
|
||||
// of the metadata tensor will live in 4 different threads.
|
||||
// This functions does the required warp shuffling to send data to the
|
||||
// right threads.
|
||||
// This took some time to write (and get right), hopefully these slides
|
||||
// can help
|
||||
// https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g249eb2e2f2e_0_28
|
||||
CUTLASS_DEVICE uint32_t
|
||||
warp_shuffle_meta(uint32_t meta_ab, bool transposed = false) {
|
||||
// The required format is
|
||||
// (one line = 32 bits)
|
||||
// a[ 0, 0:16] a[ 8, 0:16] <- T0 [left]
|
||||
// a[ 0, 16:32] a[ 8, 16:32]
|
||||
// a[16, 0:16] a[24, 0:16]
|
||||
// a[16, 16:32] a[24, 16:32]
|
||||
// a[ 1, 0:16] a[ 9, 0:16] <- T4
|
||||
// a[ 1, 16:32] a[ 9, 16:32]
|
||||
// a[17, 0:16] a[25, 0:16]
|
||||
// a[17, 16:32] a[25, 16:32]
|
||||
// a[ 2, 0:16] a[10, 0:16] <- T1 [left, bottom]
|
||||
// a[ 2, 16:32] a[10, 16:32]
|
||||
// a[18, 0:16] a[26, 0:16]
|
||||
// a[18, 16:32] a[26, 16:32]
|
||||
// a[ 3, 0:16] a[11, 0:16] <- T5 [bottom]
|
||||
// a[ 3, 16:32] a[11, 16:32]
|
||||
// a[19, 0:16] a[27, 0:16]
|
||||
// a[19, 16:32] a[27, 16:32]
|
||||
// ...
|
||||
// Use warp-shuffles to send data around threads
|
||||
bool thread_left = (threadIdx.y % 2) == 0;
|
||||
bool thread_bottom = threadIdx.x % 2;
|
||||
|
||||
if (transposed) {
|
||||
thread_left = (threadIdx.x % 2) == 0;
|
||||
thread_bottom = threadIdx.y % 2;
|
||||
}
|
||||
|
||||
uint8b_t stage0_data[2] = {
|
||||
uint8b_t(meta_ab >> (8 * thread_left)),
|
||||
uint8b_t(meta_ab >> (8 * (thread_left + 2)))};
|
||||
// shfl t0-t4 / t1-t5
|
||||
stage0_data[0] =
|
||||
__shfl_xor_sync(0xffffffff, stage0_data[0], transposed ? 1 : 4);
|
||||
stage0_data[1] =
|
||||
__shfl_xor_sync(0xffffffff, stage0_data[1], transposed ? 1 : 4);
|
||||
|
||||
uint16_t line0 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left))))
|
||||
<< ((1 - thread_left) * 8);
|
||||
line0 |= int(stage0_data[0]) << (thread_left * 8);
|
||||
uint16_t line1 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left + 2))))
|
||||
<< ((1 - thread_left) * 8);
|
||||
line1 |= int(stage0_data[1]) << (thread_left * 8);
|
||||
|
||||
uint16_t stage1_data = thread_bottom ? line0 : line1;
|
||||
stage1_data = __shfl_xor_sync(0xffffffff, stage1_data, transposed ? 4 : 1);
|
||||
|
||||
uint32_t final_metadata;
|
||||
if (thread_bottom) {
|
||||
final_metadata = uint32_t(stage1_data) | uint32_t(line1) << 16;
|
||||
} else {
|
||||
final_metadata = uint32_t(stage1_data) << 16 | uint32_t(line0);
|
||||
}
|
||||
return final_metadata;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void warp_shuffle_and_write_meta(
|
||||
ElementInputE* metadata_quad,
|
||||
uint32_t meta_ab,
|
||||
bool transposed = false) {
|
||||
bool thread_left = (threadIdx.y % 2) == 0;
|
||||
bool thread_bottom = threadIdx.x % 2;
|
||||
|
||||
if (transposed) {
|
||||
thread_left = (threadIdx.x % 2) == 0;
|
||||
thread_bottom = threadIdx.y % 2;
|
||||
}
|
||||
|
||||
uint32_t final_metadata = warp_shuffle_meta(meta_ab, transposed);
|
||||
|
||||
int index = (!thread_left + 2 * thread_bottom) * 4;
|
||||
((uint32_t*)metadata_quad)[index] = final_metadata;
|
||||
}
|
||||
|
||||
template <typename Element_>
|
||||
struct KernelTypes {
|
||||
using Element = Element_;
|
||||
using Fragment =
|
||||
cutlass::Array<Element, 8>; // always read from gmem in chunks of 128bits
|
||||
using Fragment4 = cutlass::Array<Element, 4>;
|
||||
using ValuesPacked = cutlass::Array<Element, 8>; // 4 first col, 4 second col
|
||||
|
||||
struct Params {
|
||||
/// inputs
|
||||
Element const* input;
|
||||
int64_t input_s0;
|
||||
int64_t input_dim0;
|
||||
int64_t input_dim1;
|
||||
|
||||
/// outputs
|
||||
Element* packed;
|
||||
int64_t packed_stride;
|
||||
|
||||
Element* packed_trans;
|
||||
int64_t packed_trans_stride;
|
||||
|
||||
uint64_t* threads_masks;
|
||||
|
||||
__host__ dim3 getBlocksGrid() const {
|
||||
return dim3(
|
||||
cutlass::ceil_div(input_dim0, kWarpX),
|
||||
cutlass::ceil_div(input_dim1, kWarpY),
|
||||
1);
|
||||
}
|
||||
|
||||
static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
|
||||
return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
|
||||
Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
|
||||
gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
|
||||
int64_t strideX = gridDim.y * getThreadsGrid().y;
|
||||
gmem_threads_masks +=
|
||||
(blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
|
||||
return gmem_threads_masks;
|
||||
}
|
||||
};
|
||||
|
||||
struct Tile4x4Accessor {
|
||||
using Element = Element_;
|
||||
|
||||
Fragment (&_lines)[8];
|
||||
int _start_row;
|
||||
int _start_col;
|
||||
|
||||
CUTLASS_DEVICE Tile4x4Accessor(
|
||||
Fragment (&lines)[8],
|
||||
int start_row,
|
||||
int start_col)
|
||||
: _lines(lines), _start_row(start_row), _start_col(start_col) {}
|
||||
|
||||
CUTLASS_DEVICE typename Fragment::reference at(int r, int c) {
|
||||
return _lines[r + _start_row][c + _start_col];
|
||||
}
|
||||
};
|
||||
|
||||
struct Tile4x4Packed {
|
||||
Fragment4 values[2];
|
||||
CUTLASS_DEVICE Tile4x4Packed() {
|
||||
values[0].clear();
|
||||
values[1].clear();
|
||||
}
|
||||
};
|
||||
|
||||
// Returns a packed 4x4 tile (eg 2x4 values) which correspond to the values
|
||||
// that are in `indices`. Also fills the `meta` array in the right format
|
||||
// for consumption in the TensorCores.
|
||||
// Example:
|
||||
// indices: 0011
|
||||
// 1001
|
||||
// 1001
|
||||
// 0100 (<- note, only 1 value on the last line)
|
||||
// packed: values[0][2] values[1][0] values[2][0] values[3][1]
|
||||
// values[0][3] values[1][3] values[2][3] Element(0)
|
||||
CUTLASS_DEVICE static Tile4x4Packed pack_4x4(
|
||||
Indices4x4 indices,
|
||||
Tile4x4Accessor tile,
|
||||
uint32_t& meta,
|
||||
int meta_pos,
|
||||
bool transpose = false) {
|
||||
Tile4x4Packed packed;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < 4; ++row) {
|
||||
uint2b_t col0_from, col1_from;
|
||||
auto packValue = [&](uint2b_t col_to, uint2b_t col_from) {
|
||||
auto value = transpose ? tile.at(col_from, row).get()
|
||||
: tile.at(row, col_from).get();
|
||||
packed.values[col_to][row] = value;
|
||||
if (col_to == uint2b_t(0)) {
|
||||
col0_from = col_from;
|
||||
} else {
|
||||
col1_from = col_from;
|
||||
}
|
||||
};
|
||||
auto isSelected = [&](int col) {
|
||||
if (transpose) {
|
||||
return indices & (1 << (row + 4 * col));
|
||||
}
|
||||
return indices & (1 << (col + 4 * row));
|
||||
};
|
||||
// Process cols 0/1
|
||||
// We know that col0 is always packed to position 0 if it's there
|
||||
// and col1 is packed to pos 0 or 1 (depending if col0 is selected)
|
||||
if (isSelected(1)) {
|
||||
packValue(0, 1);
|
||||
}
|
||||
if (isSelected(0)) {
|
||||
packValue(0, 0);
|
||||
}
|
||||
if (isSelected(0) && isSelected(1)) {
|
||||
packValue(1, 1);
|
||||
}
|
||||
// Process cols 2/3
|
||||
// same sort of heuristic
|
||||
if (isSelected(2)) {
|
||||
packValue(1, 2);
|
||||
}
|
||||
if (isSelected(3)) {
|
||||
packValue(1, 3);
|
||||
}
|
||||
if (isSelected(2) && isSelected(3)) {
|
||||
packValue(0, 2);
|
||||
}
|
||||
int add_mask = (col0_from | (col1_from << 2)) << (8 * row + meta_pos);
|
||||
meta |= add_mask;
|
||||
}
|
||||
return packed;
|
||||
}
|
||||
|
||||
struct Tile8x8Meta {
|
||||
// meta_ab[row] |= (real_col << (8*row + 2*pos))
|
||||
uint32_t meta_ab;
|
||||
uint32_t meta_cd;
|
||||
|
||||
// meta_ac_trans[col] |= (real_row << (8*col + 2*pos))
|
||||
uint32_t meta_ac_trans;
|
||||
uint32_t meta_bd_trans;
|
||||
|
||||
CUTLASS_DEVICE Tile8x8Meta() {
|
||||
meta_ab = meta_cd = meta_ac_trans = meta_bd_trans = 0;
|
||||
}
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE static void writePacked(
|
||||
Element* ptr,
|
||||
Fragment4 packed0,
|
||||
Fragment4 packed1) {
|
||||
Fragment write;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
write[i] = packed0[i].get();
|
||||
write[i + 4] = packed1[i].get();
|
||||
}
|
||||
cutlass::arch::global_store<Fragment, sizeof(Fragment)>(write, ptr, true);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE static void writePackedT(
|
||||
Element* ptr,
|
||||
int64_t stride,
|
||||
Tile4x4Packed a,
|
||||
Tile4x4Packed b) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
Fragment4 write;
|
||||
write[0] = a.values[0][i].get();
|
||||
write[1] = a.values[1][i].get();
|
||||
write[2] = b.values[0][i].get();
|
||||
write[3] = b.values[1][i].get();
|
||||
cutlass::arch::global_store<Fragment4, sizeof(Fragment4)>(
|
||||
write, ptr + i * stride, true);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Algorithm, typename MetadataStore>
|
||||
CUTLASS_DEVICE static void sparse_semi_structured_tile_kernel(
|
||||
Params p,
|
||||
MetadataStore metadata_gmem,
|
||||
Algorithm compute_tile_indices) {
|
||||
// Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
|
||||
// A, B, C and D, as displayed in the following schema:
|
||||
// +---+---+
|
||||
// | A | B |
|
||||
// +---+---+
|
||||
// | C | D |
|
||||
// +---+---+
|
||||
// Each warp (32 threads) will then be responsible for a 32x64 tile of the
|
||||
// input.
|
||||
// This configuration allows to read/write data in 128bits chunks. These
|
||||
// memory accesses are coalesced at the warp-level into 128bytes. See also:
|
||||
// https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g2494f30c7cf_0_0
|
||||
|
||||
// Top-left of the 8x8 tile we own
|
||||
int warp_x = blockIdx.x * kWarpX;
|
||||
int warp_y = blockIdx.y * kWarpY;
|
||||
int x = warp_x + threadIdx.x * kThreadX;
|
||||
int y = warp_y + threadIdx.y * kThreadY;
|
||||
|
||||
Element const* input = p.input + x * p.input_s0 + y;
|
||||
Element* packed = p.packed + x * p.packed_stride + (y / 2);
|
||||
Element* packed_trans =
|
||||
p.packed_trans + (x / 2) + y * p.packed_trans_stride;
|
||||
|
||||
Fragment lines[8]; // Contains all values from the 8x8 tile
|
||||
|
||||
Tile8x8Meta metadata;
|
||||
Tile8x8Masks indices;
|
||||
|
||||
// Load/process tiles `A` and `B`
|
||||
Element fillValue = Algorithm::template outOfBoundsFillValue<Element>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
lines[i].fill(fillValue);
|
||||
cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
|
||||
lines[i], input + i * p.input_s0, x + i < p.input_dim0);
|
||||
}
|
||||
indices.a = compute_tile_indices(Tile4x4Accessor(lines, 0, 0));
|
||||
indices.b = compute_tile_indices(Tile4x4Accessor(lines, 0, 4));
|
||||
|
||||
// Compute packed tiles A & B
|
||||
{
|
||||
Tile4x4Packed packed_a = pack_4x4(
|
||||
indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
|
||||
Tile4x4Packed packed_b = pack_4x4(
|
||||
indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
|
||||
writePackedT(packed, p.packed_stride, packed_a, packed_b);
|
||||
}
|
||||
|
||||
// Compute/store packed tiles A & B in transpose output
|
||||
Tile4x4Packed packed_trans_a = pack_4x4(
|
||||
indices.a,
|
||||
Tile4x4Accessor(lines, 0, 0),
|
||||
metadata.meta_ac_trans,
|
||||
0,
|
||||
true);
|
||||
Tile4x4Packed packed_trans_b = pack_4x4(
|
||||
indices.b,
|
||||
Tile4x4Accessor(lines, 0, 4),
|
||||
metadata.meta_bd_trans,
|
||||
0,
|
||||
true);
|
||||
// (NOTE) Now we no longer need A & B (`lines[0:4]`)
|
||||
|
||||
// Load/process tiles `C` and `D`
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 4; i < 8; ++i) {
|
||||
lines[i].fill(fillValue);
|
||||
cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
|
||||
lines[i], input + i * p.input_s0, x + i < p.input_dim0);
|
||||
}
|
||||
indices.c = compute_tile_indices(Tile4x4Accessor(lines, 4, 0));
|
||||
indices.d = compute_tile_indices(Tile4x4Accessor(lines, 4, 4));
|
||||
|
||||
// Compute packed tiles C & D
|
||||
{
|
||||
Tile4x4Packed packed_c = pack_4x4(
|
||||
indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
|
||||
Tile4x4Packed packed_d = pack_4x4(
|
||||
indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
|
||||
writePackedT(
|
||||
packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
|
||||
}
|
||||
|
||||
// Compute/store packed tiles C & D in transpose output
|
||||
Tile4x4Packed packed_trans_c = pack_4x4(
|
||||
indices.c,
|
||||
Tile4x4Accessor(lines, 4, 0),
|
||||
metadata.meta_ac_trans,
|
||||
4,
|
||||
true);
|
||||
Tile4x4Packed packed_trans_d = pack_4x4(
|
||||
indices.d,
|
||||
Tile4x4Accessor(lines, 4, 4),
|
||||
metadata.meta_bd_trans,
|
||||
4,
|
||||
true);
|
||||
|
||||
// Dump the metadata in a nice format
|
||||
*p.getCurrentThreadIndices() = indices;
|
||||
|
||||
// Store packed A, B, C & D for transposed matrix
|
||||
writePackedT(
|
||||
packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
|
||||
packed_trans += 4 * p.packed_trans_stride;
|
||||
writePackedT(
|
||||
packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);
|
||||
|
||||
// Writing meta non-transposed
|
||||
{
|
||||
ElementInputE* packed_meta_reordered = metadata_gmem.get_metaN(
|
||||
warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
|
||||
warp_shuffle_and_write_meta(packed_meta_reordered, metadata.meta_ab);
|
||||
warp_shuffle_and_write_meta(packed_meta_reordered + 32, metadata.meta_cd);
|
||||
}
|
||||
|
||||
// Writing meta transposed
|
||||
{
|
||||
ElementInputE* packed_trans_meta_reordered = metadata_gmem.get_metaT(
|
||||
warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
|
||||
warp_shuffle_and_write_meta(
|
||||
packed_trans_meta_reordered, metadata.meta_ac_trans, true);
|
||||
warp_shuffle_and_write_meta(
|
||||
packed_trans_meta_reordered + 32, metadata.meta_bd_trans, true);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE static void sparse_semi_structured_apply_kernel(Params p) {
|
||||
// See `sparse24_sparsify_both_ways_kernel`
|
||||
// It's basically the same, just that we skip
|
||||
// the part where compute the indices we keep
|
||||
|
||||
// Top-left of the 8x8 tile we own
|
||||
int warp_x = blockIdx.x * kWarpX;
|
||||
int warp_y = blockIdx.y * kWarpY;
|
||||
int x = warp_x + threadIdx.x * kThreadX;
|
||||
int y = warp_y + threadIdx.y * kThreadY;
|
||||
|
||||
Element const* input = p.input + x * p.input_s0 + y;
|
||||
Element* packed = p.packed + x * p.packed_stride + (y / 2);
|
||||
Element* packed_trans =
|
||||
p.packed_trans + (x / 2) + y * p.packed_trans_stride;
|
||||
|
||||
Fragment lines[8]; // Contains all values from the 8x8 tile
|
||||
|
||||
Tile8x8Meta metadata;
|
||||
Tile8x8Masks indices = *p.getCurrentThreadIndices();
|
||||
|
||||
// Load/process tiles `A` and `B`
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
// NB: Values outside bounds is undefined, but shouldn't
|
||||
// be used anywhere
|
||||
cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
|
||||
lines[i], input + i * p.input_s0, x + i < p.input_dim0);
|
||||
}
|
||||
|
||||
// Compute packed tiles A & B
|
||||
{
|
||||
Tile4x4Packed packed_a = pack_4x4(
|
||||
indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
|
||||
Tile4x4Packed packed_b = pack_4x4(
|
||||
indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
|
||||
writePackedT(packed, p.packed_stride, packed_a, packed_b);
|
||||
}
|
||||
|
||||
// Compute/store packed tiles A & B in transpose output
|
||||
Tile4x4Packed packed_trans_a = pack_4x4(
|
||||
indices.a,
|
||||
Tile4x4Accessor(lines, 0, 0),
|
||||
metadata.meta_ac_trans,
|
||||
0,
|
||||
true);
|
||||
Tile4x4Packed packed_trans_b = pack_4x4(
|
||||
indices.b,
|
||||
Tile4x4Accessor(lines, 0, 4),
|
||||
metadata.meta_bd_trans,
|
||||
0,
|
||||
true);
|
||||
// (NOTE) Now we no longer need A & B (`lines[0:4]`)
|
||||
|
||||
// Compute packed tiles C & D
|
||||
{
|
||||
Tile4x4Packed packed_c = pack_4x4(
|
||||
indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
|
||||
Tile4x4Packed packed_d = pack_4x4(
|
||||
indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
|
||||
writePackedT(
|
||||
packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
|
||||
}
|
||||
|
||||
// Compute/store packed tiles C & D in transpose output
|
||||
Tile4x4Packed packed_trans_c = pack_4x4(
|
||||
indices.c,
|
||||
Tile4x4Accessor(lines, 4, 0),
|
||||
metadata.meta_ac_trans,
|
||||
4,
|
||||
true);
|
||||
Tile4x4Packed packed_trans_d = pack_4x4(
|
||||
indices.d,
|
||||
Tile4x4Accessor(lines, 4, 4),
|
||||
metadata.meta_bd_trans,
|
||||
4,
|
||||
true);
|
||||
|
||||
// Store packed A, B, C & D for transposed matrix
|
||||
writePackedT(
|
||||
packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
|
||||
packed_trans += 4 * p.packed_trans_stride;
|
||||
writePackedT(
|
||||
packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace at::native
|
301
aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu
Normal file
301
aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu
Normal file
@ -0,0 +1,301 @@
|
||||
#include <ATen/ScalarOps.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/autocast_mode.h>
|
||||
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
|
||||
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/cuda/CUDAUtils.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/types.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
struct MetadataCuSparseLt {
|
||||
// Format used by cuSparseLt
|
||||
// This is based on reverse-engineering, for a visual illustration:
|
||||
// https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g29afe95bda8_0_0
|
||||
static constexpr int kStrideBlock32x32 = (32 * 32) / (sizeof(ElementInputE) * 8);
|
||||
|
||||
ElementInputE* _meta;
|
||||
ElementInputE* _meta_trans;
|
||||
int64_t _rows;
|
||||
int64_t _cols;
|
||||
|
||||
static int64_t getMetadataSize(int rows, int cols)
|
||||
{
|
||||
TORCH_CHECK(rows % 128 == 0 && cols % 128 == 0, "Only supports rows/cols multiples of 128");
|
||||
// 1 bit per dense value
|
||||
return (rows * cols) / (8 * sizeof(ElementInputE));
|
||||
}
|
||||
|
||||
// < return value of the function, packed, packed_meta >
|
||||
static std::tuple<Tensor, Tensor, Tensor> create_compressed_representation(int rows, int cols, at::Tensor const& like)
|
||||
{
|
||||
TORCH_CHECK(
|
||||
like.scalar_type() == at::ScalarType::Half ||
|
||||
like.scalar_type() == at::ScalarType::BFloat16);
|
||||
constexpr int kBytesPerScalar = 2;
|
||||
int64_t data_scalars = rows * cutlass::ceil_div(cols, 2);
|
||||
int64_t meta_scalars = getMetadataSize(rows, cols);
|
||||
|
||||
at::Tensor storage = at::empty(
|
||||
{(data_scalars + meta_scalars)},
|
||||
at::TensorOptions().device(like.device()).dtype(like.dtype()));
|
||||
|
||||
using namespace torch::indexing;
|
||||
at::Tensor packed = storage.index({Slice(None, data_scalars)})
|
||||
.view({rows, cutlass::ceil_div(cols, 2)});
|
||||
at::Tensor metadata = storage.index({Slice(data_scalars, None)});
|
||||
// TODO: Cast metadata to Short
|
||||
static_assert(kBytesPerScalar == 2, "or modify the last dim below");
|
||||
metadata = metadata.view({rows / 128, cols / 32, 256});
|
||||
return std::make_tuple(storage, packed, metadata);
|
||||
}
|
||||
|
||||
MetadataCuSparseLt(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
|
||||
_meta = (ElementInputE*)metaN.data_ptr();
|
||||
_meta_trans = (ElementInputE*)metaT.data_ptr();
|
||||
_rows = rows;
|
||||
_cols = cols;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int64_t _get_meta_offset(
|
||||
int warp_row,
|
||||
int thread_row,
|
||||
int warp_col,
|
||||
int thread_col,
|
||||
int totalRows) {
|
||||
int64_t offset = 0;
|
||||
// warp-level: Find the 128x64 tile
|
||||
offset += (warp_row / 128) * (kStrideBlock32x32 * 8);
|
||||
offset += (warp_col / 64) * (kStrideBlock32x32 * 8) * (totalRows / 128);
|
||||
// Find the 32x32 tile inside
|
||||
offset += (((warp_row + thread_row) % 128) / 32) * kStrideBlock32x32;
|
||||
offset += (((warp_col + thread_col) % 64) / 32) * (kStrideBlock32x32 * 4);
|
||||
// Inside the 32x32 tile
|
||||
offset += (warp_row % 32) * 2;
|
||||
// Top/bottom 16x16 tile
|
||||
offset += ((thread_row % 32) / 16) * 4;
|
||||
// Left/right 16x16 tile
|
||||
offset += ((thread_col % 32) / 16) * 2;
|
||||
return offset;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
ElementInputE* get_metaN(
|
||||
int warp_row,
|
||||
int thread_row,
|
||||
int warp_col,
|
||||
int thread_col) const {
|
||||
return _meta +
|
||||
_get_meta_offset(warp_row, thread_row, warp_col, thread_col, _rows);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
ElementInputE* get_metaT(
|
||||
int warp_row,
|
||||
int thread_row,
|
||||
int warp_col,
|
||||
int thread_col) const {
|
||||
return _meta_trans +
|
||||
_get_meta_offset(warp_col, thread_col, warp_row, thread_row, _cols);
|
||||
}
|
||||
};
|
||||
|
||||
struct MetadataCutlass {
|
||||
// Layout needed to run 2:4 gemms in CUTLASS
|
||||
// There is basically a hardware specific value for every
|
||||
// 32x32 dense tile (1024 bits). Then these tiles are
|
||||
// stored in a Column-Major fashion
|
||||
ElementInputE* _meta;
|
||||
ElementInputE* _meta_trans;
|
||||
int64_t _meta_reordered_sy;
|
||||
int64_t _meta_trans_reordered_sx;
|
||||
|
||||
static std::tuple<
|
||||
at::Tensor, // return value of the function
|
||||
at::Tensor, // packed
|
||||
at::Tensor // packed_meta
|
||||
>
|
||||
create_compressed_representation(int rows, int cols, at::Tensor const& like) {
|
||||
TORCH_CHECK(
|
||||
like.scalar_type() == at::ScalarType::Half ||
|
||||
like.scalar_type() == at::ScalarType::BFloat16);
|
||||
auto roundedx = cutlass::round_up(rows, kWarpX);
|
||||
auto roundedy = cutlass::round_up(cols, kWarpY);
|
||||
|
||||
// NB: Writing to `packed` tensors in transposed manner
|
||||
at::Tensor packed =
|
||||
at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, like.options());
|
||||
at::Tensor packed_meta = at::empty(
|
||||
{roundedx * roundedy / 16},
|
||||
like.options().dtype(at::ScalarType::Short))
|
||||
.view({roundedy / 32, roundedx, 2})
|
||||
.permute({1, 2, 0});
|
||||
return std::make_tuple(packed, packed, packed_meta);
|
||||
}
|
||||
MetadataCutlass(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
|
||||
_meta = (ElementInputE*)metaN.data_ptr();
|
||||
_meta_reordered_sy = metaN.stride(2);
|
||||
_meta_trans = (ElementInputE*)metaT.data_ptr();
|
||||
_meta_trans_reordered_sx = metaT.stride(2);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t _get_meta_offset(
|
||||
int warp_row,
|
||||
int thread_row,
|
||||
int warp_col,
|
||||
int thread_col,
|
||||
int64_t stride) const {
|
||||
int64_t offset = 0;
|
||||
offset += warp_row * 2 + (warp_col / 32) * stride;
|
||||
// A single warp is 32x64. The right 32x32 tile is at a different position
|
||||
offset += 64 * (thread_row / 32);
|
||||
offset += (thread_col / 32) * stride;
|
||||
// Top/bottom 16x16 tile
|
||||
offset += ((thread_row % 32) / 16) * 4;
|
||||
// Left/right 16x16 tile
|
||||
offset += ((thread_col % 32) / 16) * 2;
|
||||
return offset;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
ElementInputE* get_metaN(
|
||||
int warp_row,
|
||||
int thread_row,
|
||||
int warp_col,
|
||||
int thread_col) const {
|
||||
return _meta +
|
||||
_get_meta_offset(
|
||||
warp_row, thread_row, warp_col, thread_col, _meta_reordered_sy);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
ElementInputE* get_metaT(
|
||||
int warp_row,
|
||||
int thread_row,
|
||||
int warp_col,
|
||||
int thread_col) const {
|
||||
return _meta_trans +
|
||||
_get_meta_offset(
|
||||
warp_col,
|
||||
thread_col,
|
||||
warp_row,
|
||||
thread_row,
|
||||
_meta_trans_reordered_sx);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename KT, typename Metadata, typename Algorithm>
|
||||
__global__ void __launch_bounds__(32 /* num_threads */, 20)
|
||||
sparse_semi_structured_tile_kernel(
|
||||
typename KT::Params p,
|
||||
Metadata metadata,
|
||||
Algorithm algo) {
|
||||
KT::sparse_semi_structured_tile_kernel(p, metadata, algo);
|
||||
}
|
||||
|
||||
template <typename Element, typename MetadataFormat>
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> sparse_semi_structured_tile_typed(
|
||||
const at::Tensor input,
|
||||
std::string algorithm)
|
||||
{
|
||||
using KT = KernelTypes<Element>;
|
||||
c10::optional<at::cuda::CUDAGuard> device_guard;
|
||||
if (!input.is_meta()) {
|
||||
device_guard.emplace(input.device());
|
||||
}
|
||||
|
||||
TORCH_CHECK(input.dim() == 2, "Can only sparsify 2d tensors");
|
||||
TORCH_CHECK(
|
||||
input.stride(1) == 1,
|
||||
"Can only sparsify contiguous tensors. Sparsify the transpose otherwise.");
|
||||
|
||||
auto rows = input.size(0);
|
||||
auto cols = input.size(1);
|
||||
|
||||
auto [compressed, packed, packed_meta_reordered] =
|
||||
MetadataFormat::create_compressed_representation(rows, cols, input);
|
||||
auto [compressed_trans, packed_trans, packed_trans_meta_reordered] =
|
||||
MetadataFormat::create_compressed_representation(cols, rows, input);
|
||||
TORCH_CHECK(
|
||||
input.size(1) % 32 == 0, "Number of cols should be multiple of 32");
|
||||
|
||||
typename KT::Params p;
|
||||
p.input = (Element const*)input.data_ptr();
|
||||
p.input_s0 = input.stride(0);
|
||||
p.input_dim0 = input.size(0);
|
||||
p.input_dim1 = input.size(1);
|
||||
|
||||
p.packed = (Element*)packed.data_ptr();
|
||||
p.packed_stride = packed.stride(0);
|
||||
p.packed_trans = (Element*)packed_trans.data_ptr();
|
||||
p.packed_trans_stride = packed_trans.stride(0);
|
||||
|
||||
MetadataFormat metadata = MetadataFormat(
|
||||
packed_meta_reordered, packed_trans_meta_reordered, rows, cols);
|
||||
at::Tensor threads_masks = at::empty(
|
||||
{p.getBlocksGrid().x * p.getThreadsGrid().x,
|
||||
p.getBlocksGrid().y * p.getThreadsGrid().y,
|
||||
sizeof(p.threads_masks[0])},
|
||||
input.options().dtype(at::ScalarType::Byte));
|
||||
p.threads_masks = (uint64_t*)threads_masks.data_ptr();
|
||||
|
||||
bool kernel_launched = false;
|
||||
auto launchKernel = [&](auto algo, std::string const& algo_name) {
|
||||
if (algo_name == algorithm) {
|
||||
kernel_launched = true;
|
||||
if (input.is_meta()) {
|
||||
return;
|
||||
}
|
||||
size_t smem_bytes = 0;
|
||||
sparse_semi_structured_tile_kernel<KT>
|
||||
<<<p.getBlocksGrid(),
|
||||
p.getThreadsGrid(),
|
||||
smem_bytes,
|
||||
at::cuda::getCurrentCUDAStream()>>>(p, metadata, algo);
|
||||
}
|
||||
};
|
||||
named_algorithms(launchKernel);
|
||||
TORCH_CHECK(kernel_launched, "Unknown algorithm \"", algorithm, "\"");
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
return std::make_tuple(
|
||||
compressed,
|
||||
packed_meta_reordered,
|
||||
compressed_trans,
|
||||
packed_trans_meta_reordered,
|
||||
threads_masks);
|
||||
}
|
||||
|
||||
// <packed, packed_meta_reordered, packed_trans, packed_trans_meta_reorderd, threads_masks>
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _sparse_semi_structured_tile(
|
||||
const Tensor& input,
|
||||
c10::string_view algorithm,
|
||||
bool use_cutlass)
|
||||
{
|
||||
std::string algo(algorithm.data(), algorithm.size());
|
||||
|
||||
auto runTyped = [&](auto type)
|
||||
{
|
||||
using ElementT = decltype(type);
|
||||
if (use_cutlass) {
|
||||
return sparse_semi_structured_tile_typed<ElementT, MetadataCutlass>(input, algo);
|
||||
}
|
||||
else {
|
||||
return sparse_semi_structured_tile_typed<ElementT, MetadataCuSparseLt>(input, algo);
|
||||
}
|
||||
};
|
||||
|
||||
if (input.scalar_type() == at::ScalarType::Half)
|
||||
{
|
||||
return runTyped(cutlass::half_t());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
input.scalar_type() == at::ScalarType::Half ||
|
||||
input.scalar_type() == at::ScalarType::BFloat16, input.scalar_type());
|
||||
return runTyped(cutlass::bfloat16_t());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
@ -0,0 +1,97 @@
|
||||
#include <ATen/ScalarOps.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
template <typename KT>
|
||||
__global__ void __launch_bounds__(32 /* num_threads */)
|
||||
sparse_semi_structured_apply_kernel(typename KT::Params p)
|
||||
{
|
||||
KT::sparse_semi_structured_apply_kernel(p);
|
||||
}
|
||||
|
||||
// Apply a 2:4 sparsify pattern computed with
|
||||
// `_sparse_semi_structured_tile` to another Tensor
|
||||
template <bool kIsMeta, typename Element>
|
||||
std::tuple<Tensor, Tensor> _sparse_semi_structured_apply_typed(Tensor input, Tensor threads_masks)
|
||||
{
|
||||
using KT = KernelTypes<Element>;
|
||||
// TODO: Technically we should be able to deal with that
|
||||
// by running on the transpose of `input` and swapping
|
||||
// `packed` & `packed_t`.
|
||||
// This would require to adapt the `threads_masks` a bit tho.
|
||||
if (input.stride(1) != 1) {
|
||||
input = input.contiguous();
|
||||
}
|
||||
c10::optional<at::cuda::CUDAGuard> device_guard;
|
||||
if (!kIsMeta) {
|
||||
device_guard.emplace(input.device());
|
||||
}
|
||||
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
TORCH_CHECK(input.stride(1) == 1);
|
||||
TORCH_CHECK(input.stride(0) % 8 == 0);
|
||||
TORCH_CHECK(input.size(1) % 32 == 0, "Wrong alignment shape[1]");
|
||||
|
||||
auto roundedx = cutlass::round_up(input.size(0), kWarpX);
|
||||
auto roundedy = cutlass::round_up(input.size(1), kWarpY);
|
||||
at::Tensor packed =
|
||||
at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, input.options());
|
||||
at::Tensor packed_trans =
|
||||
at::empty({roundedy, cutlass::ceil_div(roundedx, 2)}, input.options());
|
||||
|
||||
typename KT::Params p;
|
||||
p.input = (Element const*)input.data_ptr();
|
||||
p.input_s0 = input.stride(0);
|
||||
p.input_dim0 = input.size(0);
|
||||
p.input_dim1 = input.size(1);
|
||||
|
||||
p.packed = (Element*)packed.data_ptr();
|
||||
p.packed_stride = packed.stride(0);
|
||||
p.packed_trans = (Element*)packed_trans.data_ptr();
|
||||
p.packed_trans_stride = packed_trans.stride(0);
|
||||
|
||||
p.threads_masks = (uint64_t*)threads_masks.data_ptr();
|
||||
|
||||
TORCH_CHECK(threads_masks.dim() == 3);
|
||||
TORCH_CHECK(
|
||||
threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
|
||||
TORCH_CHECK(
|
||||
threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
|
||||
TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
|
||||
TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
|
||||
TORCH_CHECK(threads_masks.stride(2) == 1);
|
||||
TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);
|
||||
|
||||
if (!kIsMeta) {
|
||||
size_t smem_bytes = 0;
|
||||
sparse_semi_structured_apply_kernel<KT>
|
||||
<<<p.getBlocksGrid(),
|
||||
p.getThreadsGrid(),
|
||||
smem_bytes,
|
||||
at::cuda::getCurrentCUDAStream()>>>(p);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
return std::make_tuple(packed, packed_trans);
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Tensor, Tensor> _sparse_semi_structured_apply(const Tensor& input, const Tensor& threads_masks) // Returned by `_sparse_semi_structured_tile`
|
||||
{
|
||||
TORCH_CHECK(
|
||||
input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16,
|
||||
"Unsupported dtype - only `float16` and `bfloat16` are supported currently"
|
||||
);
|
||||
auto result = (input.scalar_type() == at::ScalarType::Half)
|
||||
? _sparse_semi_structured_apply_typed<false, cutlass::half_t>(input, threads_masks)
|
||||
: _sparse_semi_structured_apply_typed<false, cutlass::bfloat16_t>(input, threads_masks);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
100
aten/src/ATen/native/sparse/cuda/StaticSort.h
Normal file
100
aten/src/ATen/native/sparse/cuda/StaticSort.h
Normal file
@ -0,0 +1,100 @@
|
||||
#pragma once
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
/**
|
||||
* A Functor class to create a sort for fixed sized arrays/containers with a
|
||||
* compile time generated Bose-Nelson sorting network.
|
||||
* \tparam NumElements The number of elements in the array or container to
|
||||
* sort. \tparam T The element type. \tparam Compare A
|
||||
* comparator functor class that returns true if lhs < rhs.
|
||||
*/
|
||||
template <unsigned NumElements>
|
||||
class StaticSort {
|
||||
template <class A>
|
||||
struct Swap {
|
||||
template <class T>
|
||||
CUTLASS_HOST_DEVICE void s(T& v0, T& v1) {
|
||||
// Explicitly code out the Min and Max to nudge the compiler
|
||||
// to generate branchless code.
|
||||
T t = v0 < v1 ? v0 : v1; // Min
|
||||
v1 = v0 < v1 ? v1 : v0; // Max
|
||||
v0 = t;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE Swap(A& a, const int& i0, const int& i1) {
|
||||
s(a[i0], a[i1]);
|
||||
}
|
||||
};
|
||||
|
||||
template <class A, int I, int J, int X, int Y>
|
||||
struct PB {
|
||||
CUTLASS_HOST_DEVICE PB(A& a) {
|
||||
enum {
|
||||
L = X >> 1,
|
||||
M = (X & 1 ? Y : Y + 1) >> 1,
|
||||
IAddL = I + L,
|
||||
XSubL = X - L
|
||||
};
|
||||
PB<A, I, J, L, M> p0(a);
|
||||
PB<A, IAddL, J + M, XSubL, Y - M> p1(a);
|
||||
PB<A, IAddL, J, XSubL, M> p2(a);
|
||||
}
|
||||
};
|
||||
|
||||
template <class A, int I, int J>
|
||||
struct PB<A, I, J, 1, 1> {
|
||||
CUTLASS_HOST_DEVICE PB(A& a) {
|
||||
Swap<A> s(a, I - 1, J - 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <class A, int I, int J>
|
||||
struct PB<A, I, J, 1, 2> {
|
||||
CUTLASS_HOST_DEVICE PB(A& a) {
|
||||
Swap<A> s0(a, I - 1, J);
|
||||
Swap<A> s1(a, I - 1, J - 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <class A, int I, int J>
|
||||
struct PB<A, I, J, 2, 1> {
|
||||
CUTLASS_HOST_DEVICE PB(A& a) {
|
||||
Swap<A> s0(a, I - 1, J - 1);
|
||||
Swap<A> s1(a, I, J - 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <class A, int I, int M, bool Stop = false>
|
||||
struct PS {
|
||||
CUTLASS_HOST_DEVICE PS(A& a) {
|
||||
enum { L = M >> 1, IAddL = I + L, MSubL = M - L };
|
||||
PS<A, I, L, (L <= 1)> ps0(a);
|
||||
PS<A, IAddL, MSubL, (MSubL <= 1)> ps1(a);
|
||||
PB<A, I, IAddL, L, MSubL> pb(a);
|
||||
}
|
||||
};
|
||||
|
||||
template <class A, int I, int M>
|
||||
struct PS<A, I, M, true> {
|
||||
CUTLASS_HOST_DEVICE PS(A& a) {}
|
||||
};
|
||||
|
||||
public:
|
||||
/**
|
||||
* Sorts the array/container arr.
|
||||
* \param arr The array/container to be sorted.
|
||||
*/
|
||||
template <class Container>
|
||||
CUTLASS_HOST_DEVICE void operator()(Container& arr) const {
|
||||
PS<Container, 1, NumElements, (NumElements <= 1)> ps(arr);
|
||||
};
|
||||
|
||||
/**
|
||||
* Sorts the array arr.
|
||||
* \param arr The array to be sorted.
|
||||
*/
|
||||
template <class T>
|
||||
CUTLASS_HOST_DEVICE void operator()(T* arr) const {
|
||||
PS<T*, 1, NumElements, (NumElements <= 1)> ps(arr);
|
||||
};
|
||||
};
|
@ -523,7 +523,10 @@ aten::_sparse_mask_projection
|
||||
aten::_sparse_mask_projection.out
|
||||
aten::_sparse_mm_reduce_impl
|
||||
aten::_sparse_mm_reduce_impl_backward
|
||||
aten::_sparse_semi_structured_apply
|
||||
aten::_sparse_semi_structured_apply_dense
|
||||
aten::_sparse_semi_structured_linear
|
||||
aten::_sparse_semi_structured_tile
|
||||
aten::_sparse_softmax
|
||||
aten::_sparse_softmax.out
|
||||
aten::_sparse_softmax_backward_data
|
||||
|
@ -6,6 +6,7 @@ import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.sparse import (
|
||||
SparseSemiStructuredTensor,
|
||||
@ -14,6 +15,12 @@ from torch.sparse import (
|
||||
to_sparse_semi_structured,
|
||||
)
|
||||
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_from_dense_cutlass,
|
||||
_sparse_semi_structured_tile,
|
||||
_compute_compressed_swizzled_bitmask,
|
||||
)
|
||||
|
||||
from torch.testing import make_tensor
|
||||
|
||||
from torch.testing._internal.common_device_type import (
|
||||
@ -33,28 +40,48 @@ from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
CUSPARSELT_NUM_ALG_IDS = 4
|
||||
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
|
||||
|
||||
SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS = {}
|
||||
|
||||
_IS_SM8X = False
|
||||
|
||||
if torch.cuda.is_available():
|
||||
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
|
||||
|
||||
# check if cslt is available for now using this:
|
||||
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
|
||||
try:
|
||||
torch._cslt_compress(torch.ones(128, 256).cuda())
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
|
||||
training_dtypes = dtypes(torch.float16, torch.bfloat16)
|
||||
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
|
||||
atol_rtol_kw = {
|
||||
torch.float16: {
|
||||
"rtol": 1e-3,
|
||||
"atol": 1e-3,
|
||||
},
|
||||
torch.bfloat16: {
|
||||
"rtol": 1e-1,
|
||||
"atol": 1e-1,
|
||||
},
|
||||
}
|
||||
|
||||
def sparse24_largest_mask_2d(original):
|
||||
sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
|
||||
return sparse.to_dense().bool()
|
||||
|
||||
def sparsify24_dense(original):
|
||||
return sparse24_largest_mask_2d(original) * original
|
||||
|
||||
def rand_sparse_semi_structured_mask(
|
||||
r, c, dtype=torch.float16, device="cuda", choice=None
|
||||
@ -98,6 +125,7 @@ def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
|
||||
dense = dense.masked_fill(~mask, 0)
|
||||
return dense
|
||||
|
||||
|
||||
def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
|
||||
pattern = '2by4' if dtype != torch.float32 else '1by2'
|
||||
if pattern == '1by2':
|
||||
@ -172,8 +200,6 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
x = x.contiguous()
|
||||
return torch.nn.functional.relu(x)
|
||||
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
|
||||
|
||||
input = torch.rand(dense_input_shape, device="cuda").half()
|
||||
model = Model().eval().cuda().half()
|
||||
mod_linear = model.linear
|
||||
@ -183,7 +209,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
|
||||
|
||||
dense_result = model(input)
|
||||
mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
|
||||
mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
|
||||
sparse_result = model(input)
|
||||
|
||||
model = torch.compile(model, backend="inductor", fullgraph=True)
|
||||
@ -216,20 +242,32 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
|
||||
|
||||
|
||||
def test_sp24_compile(self) -> None:
|
||||
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
|
||||
e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
|
||||
|
||||
def fn(x, e):
|
||||
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
|
||||
y = y.t()
|
||||
return x @ y
|
||||
|
||||
# Eager
|
||||
output = fn(x, e)
|
||||
output.backward(output)
|
||||
# Torch compile
|
||||
output = torch.compile(fn)(x, e)
|
||||
output.backward(output)
|
||||
|
||||
class TestSparseSemiStructured(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
def test_to_sparse_semi_structured(self, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@ -240,18 +278,14 @@ class TestSparseSemiStructured(TestCase):
|
||||
assert isinstance(A, torch.Tensor)
|
||||
assert isinstance(A_sparse, SparseSemiStructuredTensor)
|
||||
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
@parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
|
||||
"""
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@ -259,7 +293,6 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
|
||||
if dtype is torch.int8:
|
||||
# This should fail
|
||||
if backend == "cutlass":
|
||||
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
|
||||
sparse_result = torch.mm(A_sparse, B)
|
||||
@ -272,18 +305,15 @@ class TestSparseSemiStructured(TestCase):
|
||||
sparse_result = torch.mm(A_sparse, B)
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
|
||||
and will throw an error for int8 + padding
|
||||
"""
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@ -311,9 +341,9 @@ class TestSparseSemiStructured(TestCase):
|
||||
sparse_result = torch.mm(A_sparse, B.t())
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@inference_dtypes
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A_sparse.t(), B) throws error
|
||||
@ -332,9 +362,9 @@ class TestSparseSemiStructured(TestCase):
|
||||
):
|
||||
torch.mm(A_sparse.t(), B)
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@inference_dtypes
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A, B_sparse.t()) is correct
|
||||
@ -357,9 +387,9 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@inference_dtypes
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A, B_sparse) throws error
|
||||
@ -380,7 +410,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
|
||||
@parametrize("inference_mode", [subtest(True), subtest(False)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_linear(self, dense_input_shape, inference_mode, device, backend):
|
||||
"""
|
||||
Test nn.Linear has the same numerics
|
||||
@ -408,11 +438,9 @@ class TestSparseSemiStructured(TestCase):
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_mlp(self, device, dense_input_shape, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
input = torch.rand(dense_input_shape, device=device).half()
|
||||
model = (
|
||||
nn.Sequential(
|
||||
@ -440,7 +468,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_values(self, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -450,7 +478,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
assert A_sparse.values().shape == (128, 64)
|
||||
assert (A_sparse.values() == 1).all()
|
||||
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_indices(self, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -459,16 +487,11 @@ class TestSparseSemiStructured(TestCase):
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
assert A_sparse.indices().shape == (128, 8)
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
def test_min_sparse_shape(self, dtype, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
if backend == "cutlass":
|
||||
config = SparseSemiStructuredTensorCUTLASS._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
elif backend == "cusparselt":
|
||||
config = SparseSemiStructuredTensorCUSPARSELT._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
|
||||
@ -482,8 +505,8 @@ class TestSparseSemiStructured(TestCase):
|
||||
sparse_res = torch.mm(A_sparse, B)
|
||||
assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
def test_unsupported_shape(self, dtype, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -493,7 +516,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@dtypes(*all_types_and_complex())
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_unsupported_dtype(self, dtype, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -506,7 +529,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
else:
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
@parametrize_backends
|
||||
def test_unsupported_dim(self, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -516,13 +539,323 @@ class TestSparseSemiStructured(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
|
||||
@parametrize("backend", ["cutlass"])
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_linear_cutlass(self, device, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
def create_random_mask(shape) -> torch.Tensor:
|
||||
r = random.Random(0)
|
||||
mask = torch.zeros(shape, dtype=torch.bool)
|
||||
for line in range(mask.shape[0]):
|
||||
for col in range(0, mask.shape[1], 4):
|
||||
sparsity = r.choice(
|
||||
[
|
||||
[False, False, True, True],
|
||||
[False, True, False, True],
|
||||
[True, False, False, True],
|
||||
[False, True, True, False],
|
||||
[True, False, True, False],
|
||||
[True, True, False, False],
|
||||
]
|
||||
)
|
||||
mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
|
||||
return mask
|
||||
|
||||
class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
|
||||
|
||||
@training_dtypes
|
||||
def test_prune_dense_static_sort(self, dtype) -> None:
|
||||
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
|
||||
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
|
||||
dense = torch.randn(128, 128, device="cuda", dtype=dtype)
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
|
||||
# CUTLASS
|
||||
reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
|
||||
assert torch.allclose(pruned, reference_cutlass.to_dense())
|
||||
|
||||
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
|
||||
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
|
||||
meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
|
||||
meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
|
||||
compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
|
||||
reference_cutlass.compressed_swizzled_bitmask.stride())
|
||||
cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
|
||||
packed_cutlass,
|
||||
meta_cutlass,
|
||||
packed_t_cutlass,
|
||||
meta_t_cutlass,
|
||||
compressed_swizzled_bitmask)
|
||||
assert torch.allclose(reference_cutlass.to_dense(), cutlass.to_dense())
|
||||
|
||||
# CUSPARSELT
|
||||
reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
|
||||
algorithm="largest_abs_values_greedy")
|
||||
assert torch.allclose(pruned, reference_cusparselt.to_dense())
|
||||
|
||||
packed_cusparselt = torch._cslt_compress(pruned)
|
||||
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
|
||||
cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
|
||||
packed_cusparselt,
|
||||
None,
|
||||
packed_t_cusparselt,
|
||||
None,
|
||||
compressed_swizzled_bitmask)
|
||||
assert torch.allclose(reference_cusparselt.to_dense(), cusparselt.to_dense())
|
||||
|
||||
|
||||
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
|
||||
inp = torch.tensor(
|
||||
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
)
|
||||
inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
|
||||
sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")
|
||||
|
||||
mask = sInp.to_dense() / inp
|
||||
assert mask[:4, :4].int().tolist() == [
|
||||
[1, 1, 0, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[1, 0, 0, 1],
|
||||
]
|
||||
|
||||
@training_dtypes
|
||||
def test_gemm(self, dtype) -> None:
|
||||
M, N, K = 32, 32, 64
|
||||
a = torch.randn([M, K], device="cuda", dtype=dtype)
|
||||
b = torch.randn([K, N], device="cuda", dtype=dtype)
|
||||
mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)
|
||||
|
||||
a.masked_fill_(~mask, 0)
|
||||
|
||||
a_sparse = to_sparse_semi_structured(a)
|
||||
|
||||
masked_a = a * mask
|
||||
ref_out = masked_a @ b
|
||||
sp24_out = a_sparse @ b
|
||||
assert torch.allclose(ref_out, sp24_out, **atol_rtol_kw[dtype])
|
||||
|
||||
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
|
||||
M, N = 128, 256
|
||||
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
|
||||
a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
|
||||
a = a.repeat(M // 8, N // 8)
|
||||
assert a.shape == (M, N)
|
||||
a = a.cuda().to(dtype)
|
||||
b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)
|
||||
|
||||
a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)
|
||||
|
||||
mask_dense = sparse24_largest_mask_2d(a).to(dtype)
|
||||
|
||||
if backend == "cutlass":
|
||||
assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
|
||||
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
|
||||
mask_dense, use_cutlass=True)
|
||||
|
||||
sparse_mask = SparseSemiStructuredTensorCUTLASS(
|
||||
mask_dense.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
assert torch.allclose(a_sparse.meta.view(torch.short), sparse_mask.meta)
|
||||
|
||||
ref_gemm = (mask_dense * a) @ b
|
||||
pack_gemm = a_sparse @ b
|
||||
assert torch.allclose(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
|
||||
|
||||
@training_dtypes
|
||||
def test_pack_both_ways_id(self, dtype) -> None:
|
||||
N = 512
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn([N, N], dtype=dtype, device="cuda")
|
||||
b = torch.eye(N, dtype=dtype, device="cuda")
|
||||
|
||||
packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[
|
||||
:4
|
||||
]
|
||||
# Heuristic to ensure we pack the same values
|
||||
assert torch.allclose(
|
||||
packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
|
||||
)
|
||||
|
||||
mask_dense = sparse24_largest_mask_2d(a.to(dtype))
|
||||
|
||||
ref_gemm = mask_dense * a
|
||||
# Test A@B
|
||||
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
|
||||
max_diff = (ref_gemm - pack_gemm).abs().argmax()
|
||||
assert torch.allclose(
|
||||
ref_gemm, pack_gemm,
|
||||
**atol_rtol_kw[dtype]
|
||||
), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
|
||||
# Test A.t@B
|
||||
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
|
||||
max_diff = (ref_gemm - pack_gemm).abs().argmax()
|
||||
|
||||
assert torch.allclose(
|
||||
ref_gemm, pack_gemm,
|
||||
**atol_rtol_kw[dtype]
|
||||
), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
|
||||
|
||||
@training_dtypes
|
||||
def test_pack_both_ways_edge_case1(self, dtype) -> None:
|
||||
# In this case, the heuristic will keep 7 values out of 16
|
||||
# instead of 8. let's see how the kernel handles this
|
||||
quad = torch.tensor(
|
||||
[
|
||||
[2, -1, -2, -3], # Should be packed as `2 <null>`
|
||||
[-1, 8, -1, 6],
|
||||
[-1, -1, 4, 5],
|
||||
[-1, 3, 7, -1],
|
||||
],
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
)
|
||||
a = torch.randn([32, 64], dtype=dtype, device="cuda")
|
||||
a[:4, :4] = quad
|
||||
packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
|
||||
# Check first line in A
|
||||
assert packed[0, 0].item() == 2
|
||||
assert packed[0, 1].item() == 0
|
||||
# And first column in A.t
|
||||
assert packed_t[0, 0].item() == 2
|
||||
assert packed_t[0, 1].item() == 0
|
||||
|
||||
@training_dtypes
|
||||
def test_sp24_apply(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
(
|
||||
packed,
|
||||
meta,
|
||||
packed_t,
|
||||
meta_t,
|
||||
bitmask,
|
||||
) = torch._sparse_semi_structured_tile(x)
|
||||
packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
|
||||
assert torch.allclose(packed, packed2)
|
||||
assert torch.allclose(packed_t, packed_t2)
|
||||
|
||||
@training_dtypes
|
||||
def test_sp24_apply_dense(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
(
|
||||
packed,
|
||||
meta,
|
||||
packed_t,
|
||||
meta_t,
|
||||
bitmask,
|
||||
) = torch._sparse_semi_structured_tile(x)
|
||||
|
||||
expected = SparseSemiStructuredTensorCUTLASS(
|
||||
x.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
).to_dense()
|
||||
|
||||
packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
|
||||
sparse = SparseSemiStructuredTensorCUTLASS(
|
||||
x.shape,
|
||||
packed=packed2,
|
||||
meta=meta,
|
||||
packed_t=packed_t2,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
|
||||
dense = torch._sparse_semi_structured_apply_dense(x, bitmask)
|
||||
|
||||
assert torch.allclose(dense, expected)
|
||||
assert torch.allclose(sparse.to_dense(), expected)
|
||||
|
||||
|
||||
@training_dtypes
|
||||
def test_sp24_matmuls(self, dtype) -> None:
|
||||
M, N, K = 64, 256, 1024
|
||||
a = torch.randn([M, K], device="cuda", dtype=dtype)
|
||||
b = torch.randn([K, N], device="cuda", dtype=dtype)
|
||||
a_m = sparse24_largest_mask_2d(a)
|
||||
b_m = sparse24_largest_mask_2d(b)
|
||||
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
|
||||
a_s = SparseSemiStructuredTensorCUTLASS(
|
||||
a.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
|
||||
b_s = SparseSemiStructuredTensorCUTLASS(
|
||||
b.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
|
||||
assert torch.allclose(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1e-1)
|
||||
assert torch.allclose(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1e-1)
|
||||
assert torch.allclose(
|
||||
a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1e-1
|
||||
)
|
||||
assert torch.allclose(
|
||||
a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
|
||||
)
|
||||
|
||||
def test_sp24_matmuls_mat_vec(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([128], device="cuda", dtype=torch.float16)
|
||||
a_m = sparse24_largest_mask_2d(a)
|
||||
a_s = to_sparse_semi_structured(a)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
|
||||
|
||||
|
||||
def test_sp24_matmuls_bmm(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
|
||||
a_m = sparse24_largest_mask_2d(a)
|
||||
a_s = to_sparse_semi_structured(a)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
|
||||
|
||||
class TestSparseSemiStructuredCUTLASS(TestCase):
|
||||
"""
|
||||
This contains CUTLASS specific tests for
|
||||
- torch._sparse_semi_structured_linear
|
||||
"""
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
|
||||
self.skipTest('CUTLASS not enabled')
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
||||
@inference_dtypes
|
||||
def test_linear_cutlass(self, device, dtype):
|
||||
|
||||
def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
|
||||
weight = rand_sparse_semi_structured(m, k, dtype, device)
|
||||
@ -579,12 +912,8 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
|
||||
@parametrize("backend", ["cutlass"])
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_conversions(self, device, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
@inference_dtypes
|
||||
def test_conversions(self, device, dtype):
|
||||
|
||||
def run_test(r, c, device, dtype):
|
||||
dense_ref = rand_sparse_semi_structured(r, c, dtype, device)
|
||||
@ -611,12 +940,8 @@ class TestSparseSemiStructured(TestCase):
|
||||
run_test(r, c, device, dtype)
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
|
||||
@parametrize("backend", ["cutlass"])
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_conversions_all_patterns(self, device, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
@inference_dtypes
|
||||
def test_conversions_all_patterns(self, device, dtype):
|
||||
r, c = 32, 128
|
||||
|
||||
dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)
|
||||
@ -626,18 +951,23 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)
|
||||
|
||||
class TestCUSPARSELT(TestCase):
|
||||
"""
|
||||
This contains cuSPARSELt specific tests.
|
||||
"""
|
||||
|
||||
|
||||
CUSPARSELT_NUM_ALG_IDS = 4
|
||||
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
|
||||
|
||||
|
||||
class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
"""
|
||||
This contains cuSPARSELt specific tests for
|
||||
torch._cslt_compress
|
||||
torch._cslt_sparse_mm
|
||||
"""
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
|
||||
self.skipTest('cuSPARSELt not enabled')
|
||||
else:
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = False
|
||||
|
||||
@parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
|
||||
@parametrize("dense_input_shape", [(128, 128)])
|
||||
@ -651,7 +981,7 @@ class TestCUSPARSELT(TestCase):
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@dtypes(torch.float16, torch.bfloat16)
|
||||
@training_dtypes
|
||||
def test_cslt_sparse_mm_alpha(self, dtype, device):
|
||||
A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
|
||||
B = torch.ones((256, 128), device=device).to(dtype)
|
||||
@ -683,7 +1013,7 @@ class TestCUSPARSELT(TestCase):
|
||||
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@inference_dtypes
|
||||
def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
|
||||
# alg_id=3 not supported for float32 dtype
|
||||
if dtype == torch.float32 and alg_id == 3:
|
||||
@ -700,7 +1030,7 @@ class TestCUSPARSELT(TestCase):
|
||||
|
||||
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@inference_dtypes
|
||||
def test_cslt_sparse_mm_search(self, device, dtype):
|
||||
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
@ -713,9 +1043,10 @@ class TestCUSPARSELT(TestCase):
|
||||
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
|
||||
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestCUSPARSELT, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda")
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -1,20 +1,22 @@
|
||||
import torch
|
||||
|
||||
|
||||
# This is PyTorch implementation of main part of reorder_meta()
|
||||
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
||||
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
||||
# GEMM decides upon layout of this matrix, and at the moment for the
|
||||
# sparse GEMM executed on tensor cores, this is layout described by
|
||||
# ColumnMajorInterleaved<2> data structure, in
|
||||
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
||||
# reordering of meta matrix into meta_reordered matrix calculated
|
||||
# according to these segments of CUTLASS code is re-implemented here.
|
||||
# Note that this calculation produces offsets for scattering metadata
|
||||
# matrix elements into reordered metadata matrix elements (or,
|
||||
# equivalently, for gathering reordered metadata matrix element back
|
||||
# into metadata matrix elements).
|
||||
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
||||
"""
|
||||
This is PyTorch implementation of main part of reorder_meta()
|
||||
function, from tools/util/include/cutlass/util/host_reorder.h file
|
||||
of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
||||
GEMM decides upon layout of this matrix, and at the moment for the
|
||||
sparse GEMM executed on tensor cores, this is layout described by
|
||||
ColumnMajorInterleaved<2> data structure, in
|
||||
include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
||||
reordering of meta matrix into meta_reordered matrix calculated
|
||||
according to these segments of CUTLASS code is re-implemented here.
|
||||
Note that this calculation produces offsets for scattering metadata
|
||||
matrix elements into reordered metadata matrix elements (or,
|
||||
equivalently, for gathering reordered metadata matrix element back
|
||||
into metadata matrix elements).
|
||||
"""
|
||||
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
||||
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
||||
|
||||
@ -41,10 +43,12 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device
|
||||
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
||||
|
||||
|
||||
# This function converts dense matrix into sparse semi-structured
|
||||
# representation, producing "compressed" matrix, in the layout used by
|
||||
# CUTLASS backend, and corresponding metadata matrix.
|
||||
def sparse_semi_structured_from_dense_cutlass(dense):
|
||||
"""
|
||||
This function converts dense matrix into sparse semi-structured
|
||||
representation, producing "compressed" matrix, in the layout used by
|
||||
CUTLASS backend, and corresponding metadata matrix.
|
||||
"""
|
||||
if dense.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
|
||||
@ -172,11 +176,13 @@ def sparse_semi_structured_from_dense_cutlass(dense):
|
||||
return (sparse, meta_reordered.view(m, meta_ncols))
|
||||
|
||||
|
||||
# This function performs reverse of the function above - it
|
||||
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
||||
# in the layout used by CUTLASS backend, and accompanying metadata
|
||||
# matrix.
|
||||
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
|
||||
"""
|
||||
This function performs reverse of the function above - it
|
||||
reconstructs dense matrix from a pair of "compressed" matrix, given
|
||||
in the layout used by CUTLASS backend, and accompanying metadata
|
||||
matrix.
|
||||
"""
|
||||
if sparse.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
|
||||
@ -273,3 +279,73 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
|
||||
)
|
||||
|
||||
return dense.view(m, 2 * k)
|
||||
|
||||
|
||||
def _sparse_semi_structured_tile(dense):
|
||||
"""
|
||||
This function computes a 2:4 sparse tile by greedily taking the largest values.
|
||||
|
||||
Since we take the largest values greedily, how the sorting algorithm handles duplicates affects
|
||||
the ultimate sparsity pattern.
|
||||
|
||||
Note that this function does not have the same sorting semantics as our CUDA backend,
|
||||
which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern.
|
||||
"""
|
||||
|
||||
def greedy_prune_tile(tile):
|
||||
num_kept_row = [0, 0, 0, 0]
|
||||
num_kept_col = [0, 0, 0, 0]
|
||||
|
||||
for x in tile.flatten().sort(descending=True, stable=True).indices:
|
||||
r, c = x // 4, x % 4
|
||||
if num_kept_row[r] < 2 and num_kept_col[c] < 2:
|
||||
num_kept_row[r] += 1
|
||||
num_kept_col[c] += 1
|
||||
else:
|
||||
tile[r, c] = 0
|
||||
|
||||
for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
|
||||
for tile in batch:
|
||||
greedy_prune_tile(tile)
|
||||
|
||||
return dense
|
||||
|
||||
|
||||
def _compute_compressed_swizzled_bitmask(dense):
|
||||
"""
|
||||
Calculates the compressed swizzled bitmask from a dense tensor
|
||||
"""
|
||||
|
||||
# first we need to convert the dense tensor to a bitmask
|
||||
int_bitmask = dense.bool().to(torch.uint8)
|
||||
|
||||
# Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
|
||||
# A, B, C and D, as displayed in the following schema:
|
||||
# +---+---+
|
||||
# | A | B |
|
||||
# +---+---+
|
||||
# | C | D |
|
||||
# +---+---+
|
||||
|
||||
# we first need to split into the 8x8 tiles
|
||||
bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)
|
||||
|
||||
# then we unfold again to get our indivdual 4x4 tiles
|
||||
bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)
|
||||
|
||||
# Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern
|
||||
# of that tile. Note that the least siginificant bit is stored first.
|
||||
# [1 1 0 0]
|
||||
# [1 1 0 0] -> 0011 0011 -> 51
|
||||
# [0 0 1 1] 1100 1100 204
|
||||
# [0 0 1 1]
|
||||
|
||||
# reshape tensor to expand tiles into 8-bit vectors
|
||||
bitmask_binary_representation = bitmask_4x4_chunks.reshape(*bitmask_4x4_chunks.shape[:2], 4, 2, 8)
|
||||
|
||||
# to convert from binary representaiton, we can do a matmul with powers of two
|
||||
powers_of_two = 2**torch.arange(8, dtype=torch.float, device="cuda")
|
||||
# To run on GPU: cast to float to do matmul and then cast back
|
||||
compressed_swizzled_bitmask = (bitmask_binary_representation.to(torch.float) @ powers_of_two).to(torch.uint8)
|
||||
|
||||
return compressed_swizzled_bitmask
|
||||
|
@ -70,8 +70,8 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
|
||||
meta=self.meta_t,
|
||||
packed_t=self.packed,
|
||||
meta_t=self.meta,
|
||||
threads_masks=self.threads_masks.transpose(0, 1)
|
||||
if self.threads_masks is not None
|
||||
compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
|
||||
if self.compressed_swizzled_bitmask is not None
|
||||
else None,
|
||||
fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
|
||||
alg_id_cusparselt=args[0].alg_id_cusparselt,
|
||||
@ -97,7 +97,7 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
|
||||
meta=self.meta,
|
||||
packed_t=self.packed_t,
|
||||
meta_t=self.meta_t,
|
||||
threads_masks=self.threads_masks,
|
||||
compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
|
@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, List, Callable, Dict
|
||||
import torch
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_from_dense_cutlass,
|
||||
sparse_semi_structured_to_dense_cutlass,
|
||||
sparse_semi_structured_to_dense_cutlass
|
||||
)
|
||||
from torch.sparse._semi_structured_ops import (
|
||||
fallback_dispatcher,
|
||||
@ -56,17 +56,18 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
_FUSE_TRANSPOSE: bool = False
|
||||
_PROTOTYPE_WARNING_SHOWN: bool = False
|
||||
|
||||
BACKEND: str
|
||||
SPARSE_DISPATCH: Dict[Callable, Callable]
|
||||
|
||||
packed: Optional[torch.Tensor]
|
||||
meta: Optional[torch.Tensor]
|
||||
packed_t: Optional[torch.Tensor]
|
||||
meta_t: Optional[torch.Tensor]
|
||||
threads_masks: Optional[torch.Tensor]
|
||||
compressed_swizzled_bitmask: Optional[torch.Tensor]
|
||||
fuse_transpose_cusparselt: bool
|
||||
alg_id_cusparselt: int
|
||||
|
||||
__slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
|
||||
__slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
|
||||
|
||||
@staticmethod
|
||||
def __new__( # noqa: PYI034
|
||||
@ -76,7 +77,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta: Optional[torch.Tensor],
|
||||
packed_t: Optional[torch.Tensor],
|
||||
meta_t: Optional[torch.Tensor],
|
||||
threads_masks: Optional[torch.Tensor],
|
||||
compressed_swizzled_bitmask: Optional[torch.Tensor],
|
||||
fuse_transpose_cusparselt: bool = False,
|
||||
alg_id_cusparselt: int = 0,
|
||||
requires_grad: bool = False,
|
||||
@ -95,8 +96,8 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta: The metadata of the original dense tensor, if it is stored separately
|
||||
packed_t: The compressed representation of the transposed original dense tensor
|
||||
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
|
||||
threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
|
||||
Used for pointwise ops.
|
||||
compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
|
||||
participate in the computation. Used for pointwise ops.
|
||||
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
|
||||
with a matmul, which is useful in the case of 2:4 sparse training.
|
||||
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
|
||||
@ -124,6 +125,9 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
|
||||
cls._load_dispatch_table()
|
||||
|
||||
# we can also register the classes with dynamo when the warning is shown.
|
||||
torch._dynamo.allow_in_graph(cls)
|
||||
|
||||
if packed is not None:
|
||||
previous_tensor = packed
|
||||
elif packed_t is not None:
|
||||
@ -143,7 +147,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
tensor.meta = meta
|
||||
tensor.packed_t = packed_t
|
||||
tensor.meta_t = meta_t
|
||||
tensor.threads_masks = threads_masks
|
||||
tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
|
||||
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
|
||||
tensor.alg_id_cusparselt = alg_id_cusparselt
|
||||
return tensor
|
||||
@ -181,7 +185,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta=inner_tensors.get("meta", None),
|
||||
packed_t=inner_tensors.get("packed_t", None),
|
||||
meta_t=inner_tensors.get("meta_t", None),
|
||||
threads_masks=inner_tensors.get("threads_masks", None),
|
||||
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
|
||||
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
|
||||
alg_id_cusparselt=alg_id_cusparselt,
|
||||
requires_grad=requires_grad,
|
||||
@ -216,6 +220,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
torch.ops.aten.matmul: semi_sparse_mm,
|
||||
torch.ops.aten.addmm: semi_sparse_addmm,
|
||||
torch.ops.aten.linear: semi_sparse_linear,
|
||||
torch.ops.aten._to_copy: fallback_dispatcher,
|
||||
}
|
||||
if custom_dispatch_table is not None:
|
||||
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
|
||||
@ -359,13 +364,14 @@ def to_sparse_semi_structured(
|
||||
"SparseSemiStructuredTensor only support contiguous input tensors. "
|
||||
)
|
||||
|
||||
sparse_subclass = (
|
||||
# set from _FORCE_CUTLASS flag
|
||||
SPARSE_SUBCLASS = (
|
||||
torch.sparse.SparseSemiStructuredTensorCUTLASS
|
||||
if SparseSemiStructuredTensor._FORCE_CUTLASS
|
||||
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
|
||||
)
|
||||
return sparse_subclass.from_dense(original_tensor)
|
||||
|
||||
return SPARSE_SUBCLASS.from_dense(original_tensor)
|
||||
|
||||
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
"""
|
||||
@ -377,7 +383,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
|
||||
and sparse_semi_structured_from_dense for conversion to the compressed format.
|
||||
"""
|
||||
|
||||
BACKEND = "cutlass"
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
|
||||
@ -400,19 +406,71 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
meta=meta_tensor_cutlass,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
threads_masks=None,
|
||||
compressed_swizzled_bitmask=None,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
def to_dense(self):
|
||||
assert self.meta is not None and self.packed is not None
|
||||
return (
|
||||
sparse_semi_structured_to_dense_cutlass(
|
||||
self.packed,
|
||||
self.meta,
|
||||
)
|
||||
if self.meta.ndim == 2
|
||||
else super().to_dense()
|
||||
return sparse_semi_structured_to_dense_cutlass(
|
||||
self.packed,
|
||||
self.meta,
|
||||
) if self.meta.ndim == 2 else super().to_dense()
|
||||
|
||||
@classmethod
|
||||
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
|
||||
"""
|
||||
This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
|
||||
|
||||
It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
|
||||
The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
|
||||
|
||||
Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
|
||||
It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
|
||||
pruned dense tensor.
|
||||
Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
|
||||
|
||||
Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
|
||||
This can be used in the backward pass to mask the gradients.
|
||||
|
||||
[9 1 7 4] [9 0 7 0]
|
||||
[1 2 3 0] [0 2 0 0]
|
||||
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
|
||||
[1 2 6 2] [0 0 6 2] -> metadata
|
||||
|
||||
-> pack to transposed CUTLASS -> packed_t
|
||||
semi-structured representation -> metadata_t
|
||||
|
||||
-> compute swizzled bitmask -> compressed_swizzled_bitmask
|
||||
|
||||
|
||||
The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
|
||||
```
|
||||
from torch.sparse import SparseSemiStructuredTensorCUTLASS
|
||||
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
|
||||
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
|
||||
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
|
||||
bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
|
||||
SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
|
||||
```
|
||||
"""
|
||||
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
|
||||
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
|
||||
original_tensor,
|
||||
algorithm=algorithm,
|
||||
use_cutlass=True)
|
||||
|
||||
return cls(
|
||||
original_tensor.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def _mm(
|
||||
@ -453,7 +511,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
|
||||
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
|
||||
"""
|
||||
|
||||
BACKEND = "cusparselt"
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
|
||||
@ -470,12 +528,59 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||
meta=None,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
threads_masks=None,
|
||||
compressed_swizzled_bitmask=None,
|
||||
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
|
||||
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
|
||||
"""
|
||||
This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
|
||||
layout and sparse matmul.
|
||||
|
||||
The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
|
||||
|
||||
[9 1 7 4] [9 0 7 0]
|
||||
[1 2 3 0] [0 2 0 0]
|
||||
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
|
||||
[1 2 6 2] [0 0 6 2]
|
||||
|
||||
-> pack to transposed cuSPARSELt -> packed_t
|
||||
semi-structured representation
|
||||
|
||||
-> compute swizzled bitmask -> compressed_swizzled_bitmask
|
||||
|
||||
|
||||
The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
|
||||
```
|
||||
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
|
||||
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
|
||||
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
packed_cusparselt = torch._cslt_compress(pruned)
|
||||
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
|
||||
bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
|
||||
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
|
||||
```
|
||||
"""
|
||||
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
|
||||
original_tensor,
|
||||
algorithm=algorithm,
|
||||
use_cutlass=False)
|
||||
|
||||
return cls(
|
||||
original_tensor.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def _mm(
|
||||
self,
|
||||
B: torch.Tensor,
|
||||
|
Reference in New Issue
Block a user