Mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[sparse] Add fast semi-structured sparsification kernels (#122350)"
This reverts commit 14b2273b0c58b4000e10b2e441341eeafb7dd2f6. Reverted https://github.com/pytorch/pytorch/pull/122350 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/122350#issuecomment-2061070350))
@@ -3342,18 +3342,6 @@
  dispatch:
    CUDA: _cslt_sparse_mm_search

- func: _sparse_semi_structured_tile(Tensor input, str algorithm="", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _sparse_semi_structured_tile

- func: _sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor)
  dispatch:
    CUDA: _sparse_semi_structured_apply

- func: _sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor
  dispatch:
    CUDA: _sparse_semi_structured_apply_dense

# DEPRECATED: Use torch.__sparse_semi_structured_mm/torch._sparse_semi_structured_addmm instead
- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
  dispatch:
@@ -1,184 +0,0 @@
#pragma once

#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#include <ATen/native/sparse/cuda/StaticSort.h>
#include <cutlass/bfloat16.h>
#include <cutlass/half.h>

// Given 4x4 values, computes the selected indices that will remain after 2:4
// sparsification, as a bitmask.
// NOTE: Algorithms might select LESS than 8 values in total in some cases.

namespace platform {
template <>
struct numeric_limits<cutlass::bfloat16_t> {
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t infinity() {
    return cutlass::bfloat16_t::bitcast(0x7f80);
  }
};
} // namespace platform

namespace at::native{

template <typename Element, typename Pointwise>
struct TileValueOrderedT {
  union {
    struct {
      Element value;
      uint2b_t col;
      uint2b_t row;
    } parts;
    uint32_t raw;
  };
  CUTLASS_DEVICE bool operator<(
      TileValueOrderedT<Element, Pointwise> const& other) const {
    return Pointwise::apply(parts.value) < Pointwise::apply(other.parts.value);
  }
  CUTLASS_DEVICE TileValueOrderedT() {}
};

// Operations that we can apply to rank the values
struct IdentityOp {
  template <typename T>
  static T CUTLASS_HOST_DEVICE apply(T const& x) {
    return x;
  }
};
// Can be applied to rank based on absolute value
struct AbsOp {
  template <typename T>
  static T CUTLASS_HOST_DEVICE apply(T const& x) {
    return cutlass::abs(x);
  }
};

// Given 4x4 values, computes the selected indices that will remain after 2:4
// sparsification, as a bitmask. We have 2 constraints:
// (1) At most 2 values per line
// (2) At most 2 values per column
// This means we can select at most 8 values in total.
// ALGO: We use a greedy algorithm, where we take values in the 4x4
// tile in descending order. If a value fits (because the line/col is not
// already full), we select it. Then we move on to the next one.
// NOTE: This algorithm might select LESS than 8 values in total in some cases.
// NOTE (2): RF are not indexable, so we shouldn't rely on indexing
// values at any point, otherwise they will be stored in local memory.
template <typename Op = IdentityOp>
struct LargestValuesGreedy {
  template <typename T>
  static CUTLASS_DEVICE T outOfBoundsFillValue() {
    return -platform::numeric_limits<T>::infinity();
  }

  template <typename Tile4x4Accessor>
  CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) {
    using TileValueOrdered =
        TileValueOrderedT<typename Tile4x4Accessor::Element, Op>;
    using TileValuesFragment = cutlass::Array<TileValueOrdered, 4 * 4>;
    Indices4x4 indices;
    TileValuesFragment values_ordered;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 4; ++j) {
        TileValueOrdered& v = values_ordered[i * 4 + j];
        v.parts.value = values.at(i, j).get();
        v.parts.col = j;
        v.parts.row = i;
      }
    }
    // Use a sorting network (aka without branches) to avoid
    // warp divergence
    StaticSort<TileValuesFragment::kElements> sorter;
    sorter(values_ordered);

    // bitmask to store how many we have selected on a given row/col
    // 0 selected: (numPerRow >> 2*row) = 00 (0)
    // 1 selected: (numPerRow >> 2*row) = 01 (1)
    // 2 selected: (numPerRow >> 2*row) = 11 (3)
    uint32_t numPerRow = 0;
    uint32_t numPerCol = 0;
    indices = 0;

    // Take as many as we can, starting with the largest values
    CUTLASS_PRAGMA_UNROLL
    for (int i = values_ordered.size() - 1; i >= 0; i--) {
      auto& e = values_ordered[i];

      uint32_t rcount = uint2b_t(numPerRow >> 2 * e.parts.row);
      uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col);
      // NOTE: This is more efficient (yet equivalent) to:
      // `rcount != 3 && ccount != 3`
      bool selected = (rcount + ccount) <= 2;
      indices |= selected << (e.parts.col + 4 * e.parts.row);

      numPerRow |= (rcount + selected) << 2 * e.parts.row;
      numPerCol |= (ccount + selected) << 2 * e.parts.col;
    }
    return indices;
  }
};

// We consider each row independently, in order.
// This is to ensure that a row's sparsity pattern is only determined
// by its values and the rows before (but never the rows after).
// This enforces causality strictly.
template <typename Op = IdentityOp>
struct Causal1122 {
  template <typename T>
  static CUTLASS_DEVICE T outOfBoundsFillValue() {
    return -platform::numeric_limits<T>::infinity();
  }

  template <typename Tile4x4Accessor>
  CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) {
    static constexpr int kMaxValuesPerRow[] = {1, 1, 2, 2};
    using TileValueOrdered =
        TileValueOrderedT<typename Tile4x4Accessor::Element, Op>;
    using TileValuesFragment = cutlass::Array<TileValueOrdered, 4>;
    Indices4x4 indices = 0;

    uint32_t numPerCol = 0; // <- see doc in `LargestValuesGreedy`

    CUTLASS_PRAGMA_UNROLL
    for (int row = 0; row < 4; ++row) {
      int row_count = 0;
      TileValuesFragment values_ordered;
      CUTLASS_PRAGMA_UNROLL
      for (int col = 0; col < 4; ++col) {
        TileValueOrdered& v = values_ordered[col];
        v.parts.value = values.at(row, col).get();
        v.parts.col = col;
      }
      // Use a sorting network (aka without branches) to avoid
      // warp divergence
      StaticSort<TileValuesFragment::kElements> sorter;
      sorter(values_ordered);

      // Take as many as we can, starting with the largest values
      CUTLASS_PRAGMA_UNROLL
      for (int i = values_ordered.size() - 1; i >= 0; i--) {
        auto& e = values_ordered[i];

        uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col);
        bool selected = ccount != 3 && (row_count < kMaxValuesPerRow[row]);
        indices |= selected << (e.parts.col + 4 * row);
        numPerCol |= (ccount + selected) << 2 * e.parts.col;
        row_count += selected;
      }
    }
    return indices;
  }
};

template <typename T>
void named_algorithms(T callback) {
  callback(LargestValuesGreedy<IdentityOp>(), "largest_values_greedy");
  callback(Causal1122<IdentityOp>(), "causal1122");
  callback(LargestValuesGreedy<AbsOp>(), "largest_abs_values_greedy");
  // default one
  callback(LargestValuesGreedy<IdentityOp>(), "");
}

} // namespace
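For reference, the greedy 2:4 selection performed by `LargestValuesGreedy` above can be sketched on the host with ordinary containers. The helper names below (`Tile4x4`, `select_2to4_greedy`) are hypothetical and not part of the reverted file; `std::sort` stands in for the branchless sorting network used on device, but the selection rule (at most two kept values per row and per column, largest values first) is the same.

// Illustrative host-side sketch only; assumes nothing beyond the C++ standard library.
#include <algorithm>
#include <array>
#include <cstdint>

using Tile4x4 = std::array<std::array<float, 4>, 4>;

// Returns a 16-bit mask where bit (col + 4*row) marks a kept value.
inline uint16_t select_2to4_greedy(const Tile4x4& t) {
  struct Entry { float value; int row; int col; };
  std::array<Entry, 16> order;
  for (int r = 0; r < 4; ++r)
    for (int c = 0; c < 4; ++c)
      order[r * 4 + c] = {t[r][c], r, c};
  // Largest values first (the kernel uses a sorting network instead of std::sort).
  std::sort(order.begin(), order.end(),
            [](const Entry& a, const Entry& b) { return a.value > b.value; });
  int per_row[4] = {0, 0, 0, 0};
  int per_col[4] = {0, 0, 0, 0};
  uint16_t mask = 0;
  for (const Entry& e : order) {
    if (per_row[e.row] < 2 && per_col[e.col] < 2) {
      mask |= uint16_t(1u << (e.col + 4 * e.row));
      ++per_row[e.row];
      ++per_col[e.col];
    }
  }
  return mask;  // at most 2 values per row and per column survive
}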
@@ -1,186 +0,0 @@
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/Functions.h>
#include <ATen/autocast_mode.h>
#include <c10/cuda/CUDAGuard.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
struct Params {
  uint64_t const* threads_masks;

  uint16_t const* input;
  int64_t input_stride;
  int64_t input_dim0;
  int64_t input_dim1;

  uint16_t* output;
  int64_t output_stride;

  __host__ dim3 getBlocksGrid() const {
    return dim3(
        cutlass::ceil_div(input_dim0, kWarpX),
        cutlass::ceil_div(input_dim1, kWarpY),
        1);
  }

  static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
    return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
  }

  CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
    Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
    gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
    int64_t strideX = gridDim.y * getThreadsGrid().y;
    gmem_threads_masks +=
        (blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
    return gmem_threads_masks;
  }
};

template <bool kInputRowMajor = true, bool kOutputRowMajor = true>
__global__ void __launch_bounds__(32 /* num_threads */, 32) sparse_semi_structured_apply_dense_k(Params p) {
  using Fragment = cutlass::Array<uint16_t, 8>;

  // Top-left of the 8x8 tile we own
  int warp_x = blockIdx.x * kWarpX;
  int warp_y = blockIdx.y * kWarpY;
  int x = warp_x + threadIdx.x * kThreadX;
  int y = warp_y + threadIdx.y * kThreadY;

  uint16_t* output = p.output + x * p.output_stride + y;
  Tile8x8Masks indices = *p.getCurrentThreadIndices();

  // Load dense
  Fragment lines[8];
  if (kInputRowMajor) {
    uint16_t const* input = p.input + x * p.input_stride + y;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_stride, true);
    }
  } else {
    uint16_t const* input = p.input + x + y * p.input_stride;
    Fragment columns[8];
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          columns[i], input + i * p.input_stride, true);
    }
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 8; ++j) {
        lines[i][j] = columns[j][i].get();
      }
    }
  }

  CUTLASS_PRAGMA_UNROLL
  for (int row = 0; row < 2; ++row) {
    Indices4x4 masks[2];
    if (row == 0) {
      masks[0] = indices.a;
      masks[1] = indices.b;
    } else {
      masks[0] = indices.c;
      masks[1] = indices.d;
    }

    // Apply mask
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < 2; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int r = 0; r < 4; ++r) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < 4; ++c) {
          lines[4 * row + r][4 * m + c] = lines[4 * row + r][4 * m + c] *
              int((masks[m] >> (4 * r + c)) & 1);
        }
      }
    }
  }
  static_assert(kOutputRowMajor, "Transpose here for ColMajor output");
  // Save dense with zeros
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 8; ++i) {
    cutlass::arch::global_store<Fragment, sizeof(Fragment)>(
        lines[i], output + i * p.output_stride, true);
  }
}
#endif

Tensor _sparse_semi_structured_apply_dense(
    const Tensor& input,
    const Tensor& threads_masks) {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_apply_dense: not supported");
  return Tensor{};
#else
  TORCH_CHECK(
      input.scalar_type() == at::ScalarType::Half ||
          input.scalar_type() == at::ScalarType::BFloat16,
      "Unsupported `input` dtype");
  TORCH_CHECK(
      input.stride(0) == 1 || input.stride(1) == 1,
      "`input` should be either RowMajor or ColMajor. Invalid memory layout - try .contiguous()?");

  auto roundedx = cutlass::round_up(input.size(0), kWarpX);
  auto roundedy = cutlass::round_up(input.size(1), kWarpY);

  Params p;
  p.input = (uint16_t const*)input.data_ptr();
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);
  p.threads_masks = (uint64_t const*)threads_masks.data_ptr();

  TORCH_CHECK(threads_masks.dim() == 3);
  TORCH_CHECK(threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
  TORCH_CHECK(threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
  TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.stride(2) == 1);
  TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);

  at::Tensor output = at::empty({p.input_dim0, p.input_dim1}, input.options());
  TORCH_INTERNAL_ASSERT(output.stride(-1) == 1, "expected RowMajor?");
  p.output = (uint16_t*)output.data_ptr();

  bool inputRowMajor = input.stride(-1) == 1;
  bool outputRowMajor = output.stride(-1) == 1;
  p.input_stride = input.stride(inputRowMajor ? 0 : 1);
  p.output_stride = output.stride(outputRowMajor ? 0 : 1);
  at::cuda::CUDAGuard device_guard(input.device());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  size_t smem_bytes = 0;
  if (inputRowMajor && outputRowMajor) {
    sparse_semi_structured_apply_dense_k<true, true>
        <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
  } else if (!inputRowMajor && outputRowMajor) {
    sparse_semi_structured_apply_dense_k<false, true>
        <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
  } else {
    TORCH_CHECK(
        false,
        "Unsupported configuration: `input` is ",
        inputRowMajor ? "RowMajor" : "ColMajor",
        ", and `output` is ",
        outputRowMajor ? "RowMajor" : "ColMajor");
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
#endif
}

} // namespace
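As a reading aid for the masking loop in `sparse_semi_structured_apply_dense_k` above, a host-side equivalent might look like the sketch below. The helper name `apply_masks_8x8` is made up and the code is plain C++ rather than CUDA; it only illustrates how the four 4x4 bitmasks zero out the non-selected values of an 8x8 tile.

// Illustrative only: zero out every value of an 8x8 tile that is not selected
// by the four 4x4 masks (a, b, c, d), using the bit layout (col + 4*row).
#include <cstdint>

inline void apply_masks_8x8(const uint16_t masks[4] /* a, b, c, d */, float tile[8][8]) {
  for (int row = 0; row < 2; ++row) {        // top half uses a/b, bottom half c/d
    for (int m = 0; m < 2; ++m) {
      uint16_t mask = masks[2 * row + m];
      for (int r = 0; r < 4; ++r)
        for (int c = 0; c < 4; ++c)
          tile[4 * row + r][4 * m + c] *= float((mask >> (4 * r + c)) & 1);
    }
  }
}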
@@ -1,520 +0,0 @@
#pragma once

#include <ATen/native/sparse/cuda/StaticSort.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/bfloat16.h>
#include <cutlass/fast_math.h>
#include <cutlass/half.h>
#include <cutlass/integer_subbyte.h>

namespace at::native {

using cutlass::uint1b_t;
using cutlass::uint2b_t;
using cutlass::uint4b_t;
using uint8b_t = cutlass::integer_subbyte<8, false>;
using ReorderedLayoutInputE = cutlass::layout::ColumnMajorInterleaved<2>;
using ElementInputE = uint16_t;
constexpr int kWarpX = 32;
constexpr int kWarpY = 64;
constexpr int kThreadX = 8;
constexpr int kThreadY = 8;

// bitmask of selected values, in col-major storage
// eg: indices & (1 << (col + 4 * row))
using Indices4x4 = uint16_t;

struct Tile8x8Masks {
  Indices4x4 a, b, c, d;
  CUTLASS_DEVICE Tile8x8Masks() {
    a = b = c = d = 0;
  }
};

static_assert(sizeof(Tile8x8Masks) == 8, "should be exactly uint64_t");

// Each thread has data for an 8x8 area of the input tensor
// Due to the very specific format of the metadata, 32 consecutive bits
// of the metadata tensor will live in 4 different threads.
// This function does the required warp shuffling to send data to the
// right threads.
// This took some time to write (and get right), hopefully these slides
// can help
// https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g249eb2e2f2e_0_28
CUTLASS_DEVICE uint32_t
warp_shuffle_meta(uint32_t meta_ab, bool transposed = false) {
  // The required format is
  // (one line = 32 bits)
  // a[ 0, 0:16] a[ 8, 0:16] <- T0 [left]
  // a[ 0, 16:32] a[ 8, 16:32]
  // a[16, 0:16] a[24, 0:16]
  // a[16, 16:32] a[24, 16:32]
  // a[ 1, 0:16] a[ 9, 0:16] <- T4
  // a[ 1, 16:32] a[ 9, 16:32]
  // a[17, 0:16] a[25, 0:16]
  // a[17, 16:32] a[25, 16:32]
  // a[ 2, 0:16] a[10, 0:16] <- T1 [left, bottom]
  // a[ 2, 16:32] a[10, 16:32]
  // a[18, 0:16] a[26, 0:16]
  // a[18, 16:32] a[26, 16:32]
  // a[ 3, 0:16] a[11, 0:16] <- T5 [bottom]
  // a[ 3, 16:32] a[11, 16:32]
  // a[19, 0:16] a[27, 0:16]
  // a[19, 16:32] a[27, 16:32]
  // ...
  // Use warp-shuffles to send data around threads
  bool thread_left = (threadIdx.y % 2) == 0;
  bool thread_bottom = threadIdx.x % 2;

  if (transposed) {
    thread_left = (threadIdx.x % 2) == 0;
    thread_bottom = threadIdx.y % 2;
  }

  uint8b_t stage0_data[2] = {
      uint8b_t(meta_ab >> (8 * thread_left)),
      uint8b_t(meta_ab >> (8 * (thread_left + 2)))};
  // shfl t0-t4 / t1-t5
  stage0_data[0] =
      __shfl_xor_sync(0xffffffff, stage0_data[0], transposed ? 1 : 4);
  stage0_data[1] =
      __shfl_xor_sync(0xffffffff, stage0_data[1], transposed ? 1 : 4);

  uint16_t line0 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left))))
      << ((1 - thread_left) * 8);
  line0 |= int(stage0_data[0]) << (thread_left * 8);
  uint16_t line1 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left + 2))))
      << ((1 - thread_left) * 8);
  line1 |= int(stage0_data[1]) << (thread_left * 8);

  uint16_t stage1_data = thread_bottom ? line0 : line1;
  stage1_data = __shfl_xor_sync(0xffffffff, stage1_data, transposed ? 4 : 1);

  uint32_t final_metadata;
  if (thread_bottom) {
    final_metadata = uint32_t(stage1_data) | uint32_t(line1) << 16;
  } else {
    final_metadata = uint32_t(stage1_data) << 16 | uint32_t(line0);
  }
  return final_metadata;
}

CUTLASS_DEVICE void warp_shuffle_and_write_meta(
    ElementInputE* metadata_quad,
    uint32_t meta_ab,
    bool transposed = false) {
  bool thread_left = (threadIdx.y % 2) == 0;
  bool thread_bottom = threadIdx.x % 2;

  if (transposed) {
    thread_left = (threadIdx.x % 2) == 0;
    thread_bottom = threadIdx.y % 2;
  }

  uint32_t final_metadata = warp_shuffle_meta(meta_ab, transposed);

  int index = (!thread_left + 2 * thread_bottom) * 4;
  ((uint32_t*)metadata_quad)[index] = final_metadata;
}

template <typename Element_>
struct KernelTypes {
  using Element = Element_;
  using Fragment =
      cutlass::Array<Element, 8>; // always read from gmem in chunks of 128 bits
  using Fragment4 = cutlass::Array<Element, 4>;
  using ValuesPacked = cutlass::Array<Element, 8>; // 4 first col, 4 second col

  struct Params {
    /// inputs
    Element const* input;
    int64_t input_s0;
    int64_t input_dim0;
    int64_t input_dim1;

    /// outputs
    Element* packed;
    int64_t packed_stride;

    Element* packed_trans;
    int64_t packed_trans_stride;

    uint64_t* threads_masks;

    __host__ dim3 getBlocksGrid() const {
      return dim3(
          cutlass::ceil_div(input_dim0, kWarpX),
          cutlass::ceil_div(input_dim1, kWarpY),
          1);
    }

    static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
      return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
    }

    CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
      Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
      gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
      int64_t strideX = gridDim.y * getThreadsGrid().y;
      gmem_threads_masks +=
          (blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
      return gmem_threads_masks;
    }
  };

  struct Tile4x4Accessor {
    using Element = Element_;

    Fragment (&_lines)[8];
    int _start_row;
    int _start_col;

    CUTLASS_DEVICE Tile4x4Accessor(
        Fragment (&lines)[8],
        int start_row,
        int start_col)
        : _lines(lines), _start_row(start_row), _start_col(start_col) {}

    CUTLASS_DEVICE typename Fragment::reference at(int r, int c) {
      return _lines[r + _start_row][c + _start_col];
    }
  };

  struct Tile4x4Packed {
    Fragment4 values[2];
    CUTLASS_DEVICE Tile4x4Packed() {
      values[0].clear();
      values[1].clear();
    }
  };

  // Returns a packed 4x4 tile (eg 2x4 values) which correspond to the values
  // that are in `indices`. Also fills the `meta` array in the right format
  // for consumption in the TensorCores.
  // Example:
  //  indices:  0011
  //            1001
  //            1001
  //            0100  (<- note, only 1 value on the last line)
  //  packed: values[0][2] values[1][0] values[2][0] values[3][1]
  //          values[0][3] values[1][3] values[2][3] Element(0)
  CUTLASS_DEVICE static Tile4x4Packed pack_4x4(
      Indices4x4 indices,
      Tile4x4Accessor tile,
      uint32_t& meta,
      int meta_pos,
      bool transpose = false) {
    Tile4x4Packed packed;
    CUTLASS_PRAGMA_UNROLL
    for (int row = 0; row < 4; ++row) {
      uint2b_t col0_from, col1_from;
      auto packValue = [&](uint2b_t col_to, uint2b_t col_from) {
        auto value = transpose ? tile.at(col_from, row).get()
                               : tile.at(row, col_from).get();
        packed.values[col_to][row] = value;
        if (col_to == uint2b_t(0)) {
          col0_from = col_from;
        } else {
          col1_from = col_from;
        }
      };
      auto isSelected = [&](int col) {
        if (transpose) {
          return indices & (1 << (row + 4 * col));
        }
        return indices & (1 << (col + 4 * row));
      };
      // Process cols 0/1
      // We know that col0 is always packed to position 0 if it's there
      // and col1 is packed to pos 0 or 1 (depending if col0 is selected)
      if (isSelected(1)) {
        packValue(0, 1);
      }
      if (isSelected(0)) {
        packValue(0, 0);
      }
      if (isSelected(0) && isSelected(1)) {
        packValue(1, 1);
      }
      // Process cols 2/3
      // same sort of heuristic
      if (isSelected(2)) {
        packValue(1, 2);
      }
      if (isSelected(3)) {
        packValue(1, 3);
      }
      if (isSelected(2) && isSelected(3)) {
        packValue(0, 2);
      }
      int add_mask = (col0_from | (col1_from << 2)) << (8 * row + meta_pos);
      meta |= add_mask;
    }
    return packed;
  }

  struct Tile8x8Meta {
    // meta_ab[row] |= (real_col << (8*row + 2*pos))
    uint32_t meta_ab;
    uint32_t meta_cd;

    // meta_ac_trans[col] |= (real_row << (8*col + 2*pos))
    uint32_t meta_ac_trans;
    uint32_t meta_bd_trans;

    CUTLASS_DEVICE Tile8x8Meta() {
      meta_ab = meta_cd = meta_ac_trans = meta_bd_trans = 0;
    }
  };

  CUTLASS_DEVICE static void writePacked(
      Element* ptr,
      Fragment4 packed0,
      Fragment4 packed1) {
    Fragment write;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      write[i] = packed0[i].get();
      write[i + 4] = packed1[i].get();
    }
    cutlass::arch::global_store<Fragment, sizeof(Fragment)>(write, ptr, true);
  }

  CUTLASS_DEVICE static void writePackedT(
      Element* ptr,
      int64_t stride,
      Tile4x4Packed a,
      Tile4x4Packed b) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      Fragment4 write;
      write[0] = a.values[0][i].get();
      write[1] = a.values[1][i].get();
      write[2] = b.values[0][i].get();
      write[3] = b.values[1][i].get();
      cutlass::arch::global_store<Fragment4, sizeof(Fragment4)>(
          write, ptr + i * stride, true);
    }
  }

  template <typename Algorithm, typename MetadataStore>
  CUTLASS_DEVICE static void sparse_semi_structured_tile_kernel(
      Params p,
      MetadataStore metadata_gmem,
      Algorithm compute_tile_indices) {
    // Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
    // A, B, C and D, as displayed in the following schema:
    // +---+---+
    // | A | B |
    // +---+---+
    // | C | D |
    // +---+---+
    // Each warp (32 threads) will then be responsible for a 32x64 tile of the
    // input.
    // This configuration allows reading/writing data in 128-bit chunks. These
    // memory accesses are coalesced at the warp-level into 128 bytes. See also:
    // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g2494f30c7cf_0_0

    // Top-left of the 8x8 tile we own
    int warp_x = blockIdx.x * kWarpX;
    int warp_y = blockIdx.y * kWarpY;
    int x = warp_x + threadIdx.x * kThreadX;
    int y = warp_y + threadIdx.y * kThreadY;

    Element const* input = p.input + x * p.input_s0 + y;
    Element* packed = p.packed + x * p.packed_stride + (y / 2);
    Element* packed_trans =
        p.packed_trans + (x / 2) + y * p.packed_trans_stride;

    Fragment lines[8]; // Contains all values from the 8x8 tile

    Tile8x8Meta metadata;
    Tile8x8Masks indices;

    // Load/process tiles `A` and `B`
    Element fillValue = Algorithm::template outOfBoundsFillValue<Element>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    indices.a = compute_tile_indices(Tile4x4Accessor(lines, 0, 0));
    indices.b = compute_tile_indices(Tile4x4Accessor(lines, 0, 4));

    // Compute packed tiles A & B
    {
      Tile4x4Packed packed_a = pack_4x4(
          indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
      Tile4x4Packed packed_b = pack_4x4(
          indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
      writePackedT(packed, p.packed_stride, packed_a, packed_b);
    }

    // Compute/store packed tiles A & B in transpose output
    Tile4x4Packed packed_trans_a = pack_4x4(
        indices.a,
        Tile4x4Accessor(lines, 0, 0),
        metadata.meta_ac_trans,
        0,
        true);
    Tile4x4Packed packed_trans_b = pack_4x4(
        indices.b,
        Tile4x4Accessor(lines, 0, 4),
        metadata.meta_bd_trans,
        0,
        true);
    // (NOTE) Now we no longer need A & B (`lines[0:4]`)

    // Load/process tiles `C` and `D`
    CUTLASS_PRAGMA_UNROLL
    for (int i = 4; i < 8; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    indices.c = compute_tile_indices(Tile4x4Accessor(lines, 4, 0));
    indices.d = compute_tile_indices(Tile4x4Accessor(lines, 4, 4));

    // Compute packed tiles C & D
    {
      Tile4x4Packed packed_c = pack_4x4(
          indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
      Tile4x4Packed packed_d = pack_4x4(
          indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
      writePackedT(
          packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
    }

    // Compute/store packed tiles C & D in transpose output
    Tile4x4Packed packed_trans_c = pack_4x4(
        indices.c,
        Tile4x4Accessor(lines, 4, 0),
        metadata.meta_ac_trans,
        4,
        true);
    Tile4x4Packed packed_trans_d = pack_4x4(
        indices.d,
        Tile4x4Accessor(lines, 4, 4),
        metadata.meta_bd_trans,
        4,
        true);

    // Dump the metadata in a nice format
    *p.getCurrentThreadIndices() = indices;

    // Store packed A, B, C & D for transposed matrix
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
    packed_trans += 4 * p.packed_trans_stride;
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);

    // Writing meta non-transposed
    {
      ElementInputE* packed_meta_reordered = metadata_gmem.get_metaN(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(packed_meta_reordered, metadata.meta_ab);
      warp_shuffle_and_write_meta(packed_meta_reordered + 32, metadata.meta_cd);
    }

    // Writing meta transposed
    {
      ElementInputE* packed_trans_meta_reordered = metadata_gmem.get_metaT(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered, metadata.meta_ac_trans, true);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered + 32, metadata.meta_bd_trans, true);
    }
  }

  CUTLASS_DEVICE static void sparse_semi_structured_apply_kernel(Params p) {
    // See `sparse24_sparsify_both_ways_kernel`
    // It's basically the same, except that we skip
    // the part where we compute the indices we keep

    // Top-left of the 8x8 tile we own
    int warp_x = blockIdx.x * kWarpX;
    int warp_y = blockIdx.y * kWarpY;
    int x = warp_x + threadIdx.x * kThreadX;
    int y = warp_y + threadIdx.y * kThreadY;

    Element const* input = p.input + x * p.input_s0 + y;
    Element* packed = p.packed + x * p.packed_stride + (y / 2);
    Element* packed_trans =
        p.packed_trans + (x / 2) + y * p.packed_trans_stride;

    Fragment lines[8]; // Contains all values from the 8x8 tile

    Tile8x8Meta metadata;
    Tile8x8Masks indices = *p.getCurrentThreadIndices();

    // Load/process tiles `A` and `B`
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      // NB: Values outside bounds are undefined, but shouldn't
      // be used anywhere
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }

    // Compute packed tiles A & B
    {
      Tile4x4Packed packed_a = pack_4x4(
          indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
      Tile4x4Packed packed_b = pack_4x4(
          indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
      writePackedT(packed, p.packed_stride, packed_a, packed_b);
    }

    // Compute/store packed tiles A & B in transpose output
    Tile4x4Packed packed_trans_a = pack_4x4(
        indices.a,
        Tile4x4Accessor(lines, 0, 0),
        metadata.meta_ac_trans,
        0,
        true);
    Tile4x4Packed packed_trans_b = pack_4x4(
        indices.b,
        Tile4x4Accessor(lines, 0, 4),
        metadata.meta_bd_trans,
        0,
        true);
    // (NOTE) Now we no longer need A & B (`lines[0:4]`)

    // Compute packed tiles C & D
    {
      Tile4x4Packed packed_c = pack_4x4(
          indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
      Tile4x4Packed packed_d = pack_4x4(
          indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
      writePackedT(
          packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
    }

    // Compute/store packed tiles C & D in transpose output
    Tile4x4Packed packed_trans_c = pack_4x4(
        indices.c,
        Tile4x4Accessor(lines, 4, 0),
        metadata.meta_ac_trans,
        4,
        true);
    Tile4x4Packed packed_trans_d = pack_4x4(
        indices.d,
        Tile4x4Accessor(lines, 4, 4),
        metadata.meta_bd_trans,
        4,
        true);

    // Store packed A, B, C & D for transposed matrix
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
    packed_trans += 4 * p.packed_trans_stride;
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);
  }
};

} // namespace at::native
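The `Indices4x4` convention used throughout `SparseSemiStructuredPack.h` (bit `col + 4 * row` set means the value at `(row, col)` is kept) can be expanded with a small host-side helper. The function below is hypothetical and for illustration only; it is not part of the reverted file.

// Illustrative only: decode an Indices4x4 bitmask into a 4x4 keep/drop map.
#include <array>
#include <cstdint>

inline std::array<std::array<bool, 4>, 4> decode_indices4x4(uint16_t indices) {
  std::array<std::array<bool, 4>, 4> kept{};
  for (int row = 0; row < 4; ++row)
    for (int col = 0; col < 4; ++col)
      kept[row][col] = (indices >> (col + 4 * row)) & 1;
  return kept;
}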
@@ -1,312 +0,0 @@
#include <ATen/ScalarOps.h>
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/autocast_mode.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#include <cuda_runtime.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
struct MetadataCuSparseLt {
  // Format used by cuSparseLt
  // This is based on reverse-engineering, for a visual illustration:
  // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g29afe95bda8_0_0
  static constexpr int kStrideBlock32x32 = (32 * 32) / (sizeof(ElementInputE) * 8);

  ElementInputE* _meta;
  ElementInputE* _meta_trans;
  int64_t _rows;
  int64_t _cols;

  static int64_t getMetadataSize(int rows, int cols)
  {
    TORCH_CHECK(rows % 128 == 0 && cols % 128 == 0, "Only supports rows/cols multiples of 128");
    // 1 bit per dense value
    return (rows * cols) / (8 * sizeof(ElementInputE));
  }

  // < return value of the function, packed, packed_meta >
  static std::tuple<Tensor, Tensor, Tensor> create_compressed_representation(int rows, int cols, at::Tensor const& like)
  {
    TORCH_CHECK(
        like.scalar_type() == at::ScalarType::Half ||
        like.scalar_type() == at::ScalarType::BFloat16);
    constexpr int kBytesPerScalar = 2;
    int64_t data_scalars = rows * cutlass::ceil_div(cols, 2);
    int64_t meta_scalars = getMetadataSize(rows, cols);

    at::Tensor storage = at::empty(
        {(data_scalars + meta_scalars)},
        at::TensorOptions().device(like.device()).dtype(like.dtype()));

    using at::indexing::Slice;
    using at::indexing::None;
    at::Tensor packed = storage.index({Slice(None, data_scalars)})
                            .view({rows, cutlass::ceil_div(cols, 2)});
    at::Tensor metadata = storage.index({Slice(data_scalars, None)});
    // TODO: Cast metadata to Short
    static_assert(kBytesPerScalar == 2, "or modify the last dim below");
    metadata = metadata.view({rows / 128, cols / 32, 256});
    return std::make_tuple(storage, packed, metadata);
  }

  MetadataCuSparseLt(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
    _meta = (ElementInputE*)metaN.data_ptr();
    _meta_trans = (ElementInputE*)metaT.data_ptr();
    _rows = rows;
    _cols = cols;
  }
  CUTLASS_HOST_DEVICE
  static int64_t _get_meta_offset(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col,
      int totalRows) {
    int64_t offset = 0;
    // warp-level: Find the 128x64 tile
    offset += (warp_row / 128) * (kStrideBlock32x32 * 8);
    offset += (warp_col / 64) * (kStrideBlock32x32 * 8) * (totalRows / 128);
    // Find the 32x32 tile inside
    offset += (((warp_row + thread_row) % 128) / 32) * kStrideBlock32x32;
    offset += (((warp_col + thread_col) % 64) / 32) * (kStrideBlock32x32 * 4);
    // Inside the 32x32 tile
    offset += (warp_row % 32) * 2;
    // Top/bottom 16x16 tile
    offset += ((thread_row % 32) / 16) * 4;
    // Left/right 16x16 tile
    offset += ((thread_col % 32) / 16) * 2;
    return offset;
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaN(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta +
        _get_meta_offset(warp_row, thread_row, warp_col, thread_col, _rows);
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaT(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta_trans +
        _get_meta_offset(warp_col, thread_col, warp_row, thread_row, _cols);
  }
};

struct MetadataCutlass {
  // Layout needed to run 2:4 gemms in CUTLASS
  // There is basically a hardware specific value for every
  // 32x32 dense tile (1024 bits). Then these tiles are
  // stored in a Column-Major fashion
  ElementInputE* _meta;
  ElementInputE* _meta_trans;
  int64_t _meta_reordered_sy;
  int64_t _meta_trans_reordered_sx;

  static std::tuple<
      at::Tensor, // return value of the function
      at::Tensor, // packed
      at::Tensor // packed_meta
      >
  create_compressed_representation(int rows, int cols, at::Tensor const& like) {
    TORCH_CHECK(
        like.scalar_type() == at::ScalarType::Half ||
        like.scalar_type() == at::ScalarType::BFloat16);
    auto roundedx = cutlass::round_up(rows, kWarpX);
    auto roundedy = cutlass::round_up(cols, kWarpY);

    // NB: Writing to `packed` tensors in transposed manner
    at::Tensor packed =
        at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, like.options());
    at::Tensor packed_meta = at::empty(
                                 {roundedx * roundedy / 16},
                                 like.options().dtype(at::ScalarType::Short))
                                 .view({roundedy / 32, roundedx, 2})
                                 .permute({1, 2, 0});
    return std::make_tuple(packed, packed, packed_meta);
  }
  MetadataCutlass(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
    _meta = (ElementInputE*)metaN.data_ptr();
    _meta_reordered_sy = metaN.stride(2);
    _meta_trans = (ElementInputE*)metaT.data_ptr();
    _meta_trans_reordered_sx = metaT.stride(2);
  }
  CUTLASS_HOST_DEVICE
  int64_t _get_meta_offset(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col,
      int64_t stride) const {
    int64_t offset = 0;
    offset += warp_row * 2 + (warp_col / 32) * stride;
    // A single warp is 32x64. The right 32x32 tile is at a different position
    offset += 64 * (thread_row / 32);
    offset += (thread_col / 32) * stride;
    // Top/bottom 16x16 tile
    offset += ((thread_row % 32) / 16) * 4;
    // Left/right 16x16 tile
    offset += ((thread_col % 32) / 16) * 2;
    return offset;
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaN(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta +
        _get_meta_offset(
            warp_row, thread_row, warp_col, thread_col, _meta_reordered_sy);
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaT(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta_trans +
        _get_meta_offset(
            warp_col,
            thread_col,
            warp_row,
            thread_row,
            _meta_trans_reordered_sx);
  }
};

template <typename KT, typename Metadata, typename Algorithm>
__global__ void __launch_bounds__(32 /* num_threads */, 20)
    sparse_semi_structured_tile_kernel(
        typename KT::Params p,
        Metadata metadata,
        Algorithm algo) {
  KT::sparse_semi_structured_tile_kernel(p, metadata, algo);
}

template <typename Element, typename MetadataFormat>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> sparse_semi_structured_tile_typed(
    const at::Tensor input,
    std::string algorithm)
{
  using KT = KernelTypes<Element>;
  c10::optional<at::cuda::CUDAGuard> device_guard;
  if (!input.is_meta()) {
    device_guard.emplace(input.device());
  }

  TORCH_CHECK(input.dim() == 2, "Can only sparsify 2d tensors");
  TORCH_CHECK(
      input.stride(1) == 1,
      "Can only sparsify contiguous tensors. Sparsify the transpose otherwise.");

  auto rows = input.size(0);
  auto cols = input.size(1);

  auto [compressed, packed, packed_meta_reordered] =
      MetadataFormat::create_compressed_representation(rows, cols, input);
  auto [compressed_trans, packed_trans, packed_trans_meta_reordered] =
      MetadataFormat::create_compressed_representation(cols, rows, input);
  TORCH_CHECK(
      input.size(1) % 32 == 0, "Number of cols should be multiple of 32");

  typename KT::Params p;
  p.input = (Element const*)input.data_ptr();
  p.input_s0 = input.stride(0);
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);

  p.packed = (Element*)packed.data_ptr();
  p.packed_stride = packed.stride(0);
  p.packed_trans = (Element*)packed_trans.data_ptr();
  p.packed_trans_stride = packed_trans.stride(0);

  MetadataFormat metadata = MetadataFormat(
      packed_meta_reordered, packed_trans_meta_reordered, rows, cols);
  at::Tensor threads_masks = at::empty(
      {p.getBlocksGrid().x * p.getThreadsGrid().x,
       p.getBlocksGrid().y * p.getThreadsGrid().y,
       sizeof(p.threads_masks[0])},
      input.options().dtype(at::ScalarType::Byte));
  p.threads_masks = (uint64_t*)threads_masks.data_ptr();

  bool kernel_launched = false;
  auto launchKernel = [&](auto algo, std::string const& algo_name) {
    if (algo_name == algorithm) {
      kernel_launched = true;
      if (input.is_meta()) {
        return;
      }
      size_t smem_bytes = 0;
      sparse_semi_structured_tile_kernel<KT>
          <<<p.getBlocksGrid(),
             p.getThreadsGrid(),
             smem_bytes,
             at::cuda::getCurrentCUDAStream()>>>(p, metadata, algo);
    }
  };
  named_algorithms(launchKernel);
  TORCH_CHECK(kernel_launched, "Unknown algorithm \"", algorithm, "\"");
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return std::make_tuple(
      compressed,
      packed_meta_reordered,
      compressed_trans,
      packed_trans_meta_reordered,
      threads_masks);
}
#endif

// <packed, packed_meta_reordered, packed_trans, packed_trans_meta_reordered, threads_masks>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _sparse_semi_structured_tile(
    const Tensor& input,
    c10::string_view algorithm,
    bool use_cutlass)
{
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_tile: not supported");
  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{});
#else
  std::string algo(algorithm.data(), algorithm.size());

  auto runTyped = [&](auto type)
  {
    using ElementT = decltype(type);
    if (use_cutlass) {
      return sparse_semi_structured_tile_typed<ElementT, MetadataCutlass>(input, algo);
    }
    else {
      return sparse_semi_structured_tile_typed<ElementT, MetadataCuSparseLt>(input, algo);
    }
  };

  if (input.scalar_type() == at::ScalarType::Half)
  {
    return runTyped(cutlass::half_t());
  } else {
    TORCH_CHECK(
        input.scalar_type() == at::ScalarType::Half ||
        input.scalar_type() == at::ScalarType::BFloat16, input.scalar_type());
    return runTyped(cutlass::bfloat16_t());
  }
#endif
}

} // namespace at::native
@@ -1,107 +0,0 @@
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/accumulate.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
template <typename KT>
__global__ void __launch_bounds__(32 /* num_threads */)
    sparse_semi_structured_apply_kernel(typename KT::Params p)
{
  KT::sparse_semi_structured_apply_kernel(p);
}

// Apply a 2:4 sparsity pattern computed with
// `_sparse_semi_structured_tile` to another Tensor
template <bool kIsMeta, typename Element>
std::tuple<Tensor, Tensor> _sparse_semi_structured_apply_typed(Tensor input, Tensor threads_masks)
{
  using KT = KernelTypes<Element>;
  // TODO: Technically we should be able to deal with that
  // by running on the transpose of `input` and swapping
  // `packed` & `packed_t`.
  // This would require adapting the `threads_masks` a bit, though.
  if (input.stride(1) != 1) {
    input = input.contiguous();
  }
  c10::optional<at::cuda::CUDAGuard> device_guard;
  if (!kIsMeta) {
    device_guard.emplace(input.device());
  }

  TORCH_CHECK(input.dim() == 2);
  TORCH_CHECK(input.stride(1) == 1);
  TORCH_CHECK(input.stride(0) % 8 == 0);
  TORCH_CHECK(input.size(1) % 32 == 0, "Wrong alignment shape[1]");

  auto roundedx = cutlass::round_up(input.size(0), kWarpX);
  auto roundedy = cutlass::round_up(input.size(1), kWarpY);
  at::Tensor packed =
      at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, input.options());
  at::Tensor packed_trans =
      at::empty({roundedy, cutlass::ceil_div(roundedx, 2)}, input.options());

  typename KT::Params p;
  p.input = (Element const*)input.data_ptr();
  p.input_s0 = input.stride(0);
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);

  p.packed = (Element*)packed.data_ptr();
  p.packed_stride = packed.stride(0);
  p.packed_trans = (Element*)packed_trans.data_ptr();
  p.packed_trans_stride = packed_trans.stride(0);

  p.threads_masks = (uint64_t*)threads_masks.data_ptr();

  TORCH_CHECK(threads_masks.dim() == 3);
  TORCH_CHECK(
      threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
  TORCH_CHECK(
      threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
  TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.stride(2) == 1);
  TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);

  if (!kIsMeta) {
    size_t smem_bytes = 0;
    sparse_semi_structured_apply_kernel<KT>
        <<<p.getBlocksGrid(),
           p.getThreadsGrid(),
           smem_bytes,
           at::cuda::getCurrentCUDAStream()>>>(p);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  return std::make_tuple(packed, packed_trans);
}
#endif

std::tuple<Tensor, Tensor> _sparse_semi_structured_apply(const Tensor& input, const Tensor& threads_masks) // Returned by `_sparse_semi_structured_tile`
{
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_apply: not supported");
  return std::make_tuple(Tensor{}, Tensor{});
#else
  TORCH_CHECK(
      input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16,
      "Unsupported dtype - only `float16` and `bfloat16` are supported currently"
  );
  auto result = (input.scalar_type() == at::ScalarType::Half)
      ? _sparse_semi_structured_apply_typed<false, cutlass::half_t>(input, threads_masks)
      : _sparse_semi_structured_apply_typed<false, cutlass::bfloat16_t>(input, threads_masks);
  return result;
#endif
}

} // namespace
@@ -1,100 +0,0 @@
#pragma once
#include <cutlass/cutlass.h>

/**
 * A Functor class to create a sort for fixed sized arrays/containers with a
 * compile time generated Bose-Nelson sorting network.
 * \tparam NumElements  The number of elements in the array or container to
 *                      sort.
 * \tparam T            The element type.
 * \tparam Compare      A comparator functor class that returns true if lhs < rhs.
 */
template <unsigned NumElements>
class StaticSort {
  template <class A>
  struct Swap {
    template <class T>
    CUTLASS_HOST_DEVICE void s(T& v0, T& v1) {
      // Explicitly code out the Min and Max to nudge the compiler
      // to generate branchless code.
      T t = v0 < v1 ? v0 : v1; // Min
      v1 = v0 < v1 ? v1 : v0; // Max
      v0 = t;
    }

    CUTLASS_HOST_DEVICE Swap(A& a, const int& i0, const int& i1) {
      s(a[i0], a[i1]);
    }
  };

  template <class A, int I, int J, int X, int Y>
  struct PB {
    CUTLASS_HOST_DEVICE PB(A& a) {
      enum {
        L = X >> 1,
        M = (X & 1 ? Y : Y + 1) >> 1,
        IAddL = I + L,
        XSubL = X - L
      };
      PB<A, I, J, L, M> p0(a);
      PB<A, IAddL, J + M, XSubL, Y - M> p1(a);
      PB<A, IAddL, J, XSubL, M> p2(a);
    }
  };

  template <class A, int I, int J>
  struct PB<A, I, J, 1, 1> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s(a, I - 1, J - 1);
    }
  };

  template <class A, int I, int J>
  struct PB<A, I, J, 1, 2> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s0(a, I - 1, J);
      Swap<A> s1(a, I - 1, J - 1);
    }
  };

  template <class A, int I, int J>
  struct PB<A, I, J, 2, 1> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s0(a, I - 1, J - 1);
      Swap<A> s1(a, I, J - 1);
    }
  };

  template <class A, int I, int M, bool Stop = false>
  struct PS {
    CUTLASS_HOST_DEVICE PS(A& a) {
      enum { L = M >> 1, IAddL = I + L, MSubL = M - L };
      PS<A, I, L, (L <= 1)> ps0(a);
      PS<A, IAddL, MSubL, (MSubL <= 1)> ps1(a);
      PB<A, I, IAddL, L, MSubL> pb(a);
    }
  };

  template <class A, int I, int M>
  struct PS<A, I, M, true> {
    CUTLASS_HOST_DEVICE PS(A& a) {}
  };

 public:
  /**
   * Sorts the array/container arr.
   * \param arr  The array/container to be sorted.
   */
  template <class Container>
  CUTLASS_HOST_DEVICE void operator()(Container& arr) const {
    PS<Container, 1, NumElements, (NumElements <= 1)> ps(arr);
  };

  /**
   * Sorts the array arr.
   * \param arr  The array to be sorted.
   */
  template <class T>
  CUTLASS_HOST_DEVICE void operator()(T* arr) const {
    PS<T*, 1, NumElements, (NumElements <= 1)> ps(arr);
  };
};
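A minimal usage sketch for `StaticSort`, assuming the definition above and the CUTLASS headers are in scope in the same translation unit (host-only compilation here); the generated network sorts in ascending order without data-dependent branches.

// Illustrative only; not part of the reverted file.
#include <cstdio>

int main() {
  float values[8] = {3.f, 1.f, 7.f, 5.f, 2.f, 8.f, 6.f, 4.f};
  StaticSort<8> sorter;  // Bose-Nelson network generated at compile time
  sorter(values);        // expected output order: 1 2 3 4 5 6 7 8
  for (float v : values) {
    std::printf("%.0f ", v);
  }
  return 0;
}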
@@ -524,11 +524,8 @@ aten::_sparse_mask_projection.out
aten::_sparse_mm_reduce_impl
aten::_sparse_mm_reduce_impl_backward
aten::_sparse_semi_structured_addmm
aten::_sparse_semi_structured_apply
aten::_sparse_semi_structured_apply_dense
aten::_sparse_semi_structured_linear
aten::_sparse_semi_structured_mm
aten::_sparse_semi_structured_tile
aten::_sparse_softmax
aten::_sparse_softmax.out
aten::_sparse_softmax_backward_data
@ -5,7 +5,6 @@ import unittest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.sparse import (
|
||||
SparseSemiStructuredTensor,
|
||||
@ -14,12 +13,6 @@ from torch.sparse import (
|
||||
to_sparse_semi_structured,
|
||||
)
|
||||
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_from_dense_cutlass,
|
||||
_sparse_semi_structured_tile,
|
||||
_compute_compressed_swizzled_bitmask,
|
||||
)
|
||||
|
||||
from torch.testing import make_tensor
|
||||
|
||||
from torch.testing._internal.common_device_type import (
|
||||
@ -39,48 +32,28 @@ from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
CUSPARSELT_NUM_ALG_IDS = 4
|
||||
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
|
||||
|
||||
SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS = {}
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
|
||||
|
||||
_IS_SM8X = False
|
||||
|
||||
if torch.cuda.is_available():
|
||||
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")
|
||||
|
||||
# check if cslt is available for now using this:
|
||||
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
|
||||
try:
|
||||
torch._cslt_compress(torch.ones(128, 256).cuda())
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
|
||||
training_dtypes = dtypes(torch.float16, torch.bfloat16)
|
||||
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
|
||||
atol_rtol_kw = {
|
||||
torch.float16: {
|
||||
"rtol": 1e-3,
|
||||
"atol": 1e-3,
|
||||
},
|
||||
torch.bfloat16: {
|
||||
"rtol": 1e-1,
|
||||
"atol": 1e-1,
|
||||
},
|
||||
}
|
||||
|
||||
def sparse24_largest_mask_2d(original):
|
||||
sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
|
||||
return sparse.to_dense().bool()
|
||||
|
||||
def sparsify24_dense(original):
|
||||
return sparse24_largest_mask_2d(original) * original
|
||||
|
||||
def rand_sparse_semi_structured_mask(
|
||||
r, c, dtype=torch.float16, device="cuda", choice=None
|
||||
@ -124,7 +97,6 @@ def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
|
||||
dense = dense.masked_fill(~mask, 0)
|
||||
return dense
|
||||
|
||||
|
||||
def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
|
||||
pattern = '2by4' if dtype != torch.float32 else '1by2'
|
||||
if pattern == '1by2':
|
||||
@ -199,6 +171,8 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
x = x.contiguous()
|
||||
return torch.nn.functional.relu(x)
|
||||
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
|
||||
|
||||
input = torch.rand(dense_input_shape, device="cuda").half()
|
||||
model = Model().eval().cuda().half()
|
||||
mod_linear = model.linear
|
||||
@ -208,7 +182,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
|
||||
|
||||
dense_result = model(input)
|
||||
mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
|
||||
mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
|
||||
sparse_result = model(input)
|
||||
|
||||
model = torch.compile(model, backend="inductor", fullgraph=True)
|
||||
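The compile tests above follow the standard weight-swap recipe. A hedged sketch of that recipe outside torch.compile, assuming a CUDA SM8x build where `to_sparse_semi_structured` works and reusing the `rand_sparse_semi_structured_mask` helper defined earlier in this test file:

```
import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

linear = nn.Linear(128, 128).half().cuda().eval()
mask = rand_sparse_semi_structured_mask(128, 128)        # test helper defined above
linear.weight = nn.Parameter(linear.weight * mask)       # make the weight 2:4 sparse
x = torch.rand(64, 128, device="cuda").half()
dense_out = linear(x)
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight))
sparse_out = linear(x)
assert torch.allclose(dense_out, sparse_out, rtol=1e-3, atol=1e-3)
```
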
@ -239,32 +213,20 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
|
||||
|
||||
|
||||
def test_sp24_compile(self) -> None:
|
||||
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
|
||||
e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
|
||||
|
||||
def fn(x, e):
|
||||
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
|
||||
y = y.t()
|
||||
return x @ y
|
||||
|
||||
# Eager
|
||||
output = fn(x, e)
|
||||
output.backward(output)
|
||||
# Torch compile
|
||||
output = torch.compile(fn)(x, e)
|
||||
output.backward(output)
|
||||
|
||||
class TestSparseSemiStructured(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_to_sparse_semi_structured(self, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@ -275,14 +237,18 @@ class TestSparseSemiStructured(TestCase):
|
||||
assert isinstance(A, torch.Tensor)
|
||||
assert isinstance(A_sparse, SparseSemiStructuredTensor)
|
||||
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
|
||||
"""
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@ -290,6 +256,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
|
||||
if dtype is torch.int8:
|
||||
# This should fail
|
||||
if backend == "cutlass":
|
||||
with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
|
||||
sparse_result = torch.mm(A_sparse, B)
|
||||
@ -302,15 +269,18 @@ class TestSparseSemiStructured(TestCase):
|
||||
sparse_result = torch.mm(A_sparse, B)
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
|
||||
and will throw an error for int8 + padding
|
||||
"""
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@ -338,9 +308,9 @@ class TestSparseSemiStructured(TestCase):
|
||||
sparse_result = torch.mm(A_sparse, B.t())
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A_sparse.t(), B) throws error
|
||||
@ -359,9 +329,9 @@ class TestSparseSemiStructured(TestCase):
|
||||
):
|
||||
torch.mm(A_sparse.t(), B)
|
||||
|
||||
@inference_dtypes
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A, B_sparse.t()) is correct
|
||||
@ -384,9 +354,9 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
|
||||
"""
|
||||
Ensure torch.mm(A, B_sparse) throws error
|
||||
@ -407,7 +377,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
|
||||
@parametrize("inference_mode", [subtest(True), subtest(False)])
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_linear(self, dense_input_shape, inference_mode, device, backend):
|
||||
"""
|
||||
Test nn.Linear has the same numerics
|
||||
@ -435,9 +405,11 @@ class TestSparseSemiStructured(TestCase):
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_mlp(self, device, dense_input_shape, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
input = torch.rand(dense_input_shape, device=device).half()
|
||||
model = (
|
||||
nn.Sequential(
|
||||
@ -465,7 +437,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_values(self, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -475,7 +447,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
assert A_sparse.values().shape == (128, 64)
|
||||
assert (A_sparse.values() == 1).all()
|
||||
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_indices(self, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -484,11 +456,16 @@ class TestSparseSemiStructured(TestCase):
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
assert A_sparse.indices().shape == (128, 8)
|
||||
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_min_sparse_shape(self, dtype, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
if backend == "cutlass":
|
||||
config = SparseSemiStructuredTensorCUTLASS._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
elif backend == "cusparselt":
|
||||
config = SparseSemiStructuredTensorCUSPARSELT._DTYPE_SHAPE_CONSTRAINTS[dtype]
|
||||
A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
|
||||
@ -502,8 +479,8 @@ class TestSparseSemiStructured(TestCase):
|
||||
sparse_res = torch.mm(A_sparse, B)
|
||||
assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
@parametrize_backends
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_unsupported_shape(self, dtype, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -513,7 +490,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@dtypes(*all_types_and_complex())
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_unsupported_dtype(self, dtype, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -526,7 +503,7 @@ class TestSparseSemiStructured(TestCase):
|
||||
else:
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
@parametrize_backends
|
||||
@parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
|
||||
def test_unsupported_dim(self, device, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
@ -536,323 +513,13 @@ class TestSparseSemiStructured(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
|
||||
A_sparse = to_sparse_semi_structured(A)
|
||||
|
||||
|
||||
def create_random_mask(shape) -> torch.Tensor:
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask

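The six candidate patterns above are exactly the C(4,2) = 6 ways of keeping two out of four positions, so every row produced by `create_random_mask` satisfies the 2:4 constraint. A small checker for that invariant, using only plain CPU PyTorch:

```
import torch

def is_2_to_4(mask: torch.Tensor) -> bool:
    # at most 2 nonzeros in every aligned group of 4 along each row
    return bool(mask.reshape(mask.shape[0], -1, 4).sum(-1).le(2).all())

assert is_2_to_4(create_random_mask((8, 16)))
```
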
class TestSparseSemiStructuredTraining(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
|
||||
|
||||
@training_dtypes
|
||||
def test_prune_dense_static_sort(self, dtype) -> None:
|
||||
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
|
||||
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
|
||||
dense = torch.randn(128, 128, device="cuda", dtype=dtype)
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
|
||||
# CUTLASS
|
||||
reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
|
||||
assert torch.allclose(pruned, reference_cutlass.to_dense())
|
||||
|
||||
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
|
||||
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
|
||||
meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
|
||||
meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
|
||||
compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
|
||||
reference_cutlass.compressed_swizzled_bitmask.stride())
|
||||
cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
|
||||
packed_cutlass,
|
||||
meta_cutlass,
|
||||
packed_t_cutlass,
|
||||
meta_t_cutlass,
|
||||
compressed_swizzled_bitmask)
|
||||
assert torch.allclose(reference_cutlass.to_dense(), cutlass.to_dense())
|
||||
|
||||
# CUSPARSELT
|
||||
reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
|
||||
algorithm="largest_abs_values_greedy")
|
||||
assert torch.allclose(pruned, reference_cusparselt.to_dense())
|
||||
|
||||
packed_cusparselt = torch._cslt_compress(pruned)
|
||||
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
|
||||
cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
|
||||
packed_cusparselt,
|
||||
None,
|
||||
packed_t_cusparselt,
|
||||
None,
|
||||
compressed_swizzled_bitmask)
|
||||
assert torch.allclose(reference_cusparselt.to_dense(), cusparselt.to_dense())
|
||||
|
||||
|
||||
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
|
||||
inp = torch.tensor(
|
||||
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
)
|
||||
inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
|
||||
sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")
|
||||
|
||||
mask = sInp.to_dense() / inp
|
||||
assert mask[:4, :4].int().tolist() == [
|
||||
[1, 1, 0, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[1, 0, 0, 1],
|
||||
]
|
||||
|
||||
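The expected mask asserted above can be reproduced with a few lines of plain PyTorch: greedily keep the largest-magnitude entries of the 4x4 tile while capping each row and each column at two kept values. This is only a sketch of the selection rule the test pins down, not the CUDA kernel:

```
import torch

def greedy_prune_4x4(tile: torch.Tensor) -> torch.Tensor:
    kept = torch.zeros(4, 4, dtype=torch.bool)
    for idx in tile.abs().flatten().argsort(descending=True).tolist():
        r, c = divmod(idx, 4)
        # keep a value only if its row and column still have room (max 2 each)
        if kept[r].sum() < 2 and kept[:, c].sum() < 2:
            kept[r, c] = True
    return kept

tile = torch.tensor([[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]])
expected = torch.tensor([[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=torch.bool)
assert torch.equal(greedy_prune_4x4(tile), expected)
```
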
@training_dtypes
|
||||
def test_gemm(self, dtype) -> None:
|
||||
M, N, K = 32, 32, 64
|
||||
a = torch.randn([M, K], device="cuda", dtype=dtype)
|
||||
b = torch.randn([K, N], device="cuda", dtype=dtype)
|
||||
mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)
|
||||
|
||||
a.masked_fill_(~mask, 0)
|
||||
|
||||
a_sparse = to_sparse_semi_structured(a)
|
||||
|
||||
masked_a = a * mask
|
||||
ref_out = masked_a @ b
|
||||
sp24_out = a_sparse @ b
|
||||
assert torch.allclose(ref_out, sp24_out, **atol_rtol_kw[dtype])
|
||||
|
||||
|
||||
@training_dtypes
|
||||
@parametrize_backends
|
||||
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
|
||||
M, N = 128, 256
|
||||
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
|
||||
a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
|
||||
a = a.repeat(M // 8, N // 8)
|
||||
assert a.shape == (M, N)
|
||||
a = a.cuda().to(dtype)
|
||||
b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)
|
||||
|
||||
a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)
|
||||
|
||||
mask_dense = sparse24_largest_mask_2d(a).to(dtype)
|
||||
|
||||
if backend == "cutlass":
|
||||
assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
|
||||
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
|
||||
mask_dense, use_cutlass=True)
|
||||
|
||||
sparse_mask = SparseSemiStructuredTensorCUTLASS(
|
||||
mask_dense.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
assert torch.allclose(a_sparse.meta.view(torch.short), sparse_mask.meta)
|
||||
|
||||
ref_gemm = (mask_dense * a) @ b
|
||||
pack_gemm = a_sparse @ b
|
||||
assert torch.allclose(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
|
||||
|
||||
@training_dtypes
|
||||
def test_pack_both_ways_id(self, dtype) -> None:
|
||||
N = 512
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn([N, N], dtype=dtype, device="cuda")
|
||||
b = torch.eye(N, dtype=dtype, device="cuda")
|
||||
|
||||
packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[
|
||||
:4
|
||||
]
|
||||
# Heuristic to ensure we pack the same values
|
||||
assert torch.allclose(
|
||||
packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
|
||||
)
|
||||
|
||||
mask_dense = sparse24_largest_mask_2d(a.to(dtype))
|
||||
|
||||
ref_gemm = mask_dense * a
|
||||
# Test A@B
|
||||
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
|
||||
max_diff = (ref_gemm - pack_gemm).abs().argmax()
|
||||
assert torch.allclose(
|
||||
ref_gemm, pack_gemm,
|
||||
**atol_rtol_kw[dtype]
|
||||
), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
|
||||
# Test A.t@B
|
||||
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
|
||||
max_diff = (ref_gemm - pack_gemm).abs().argmax()
|
||||
|
||||
assert torch.allclose(
|
||||
ref_gemm, pack_gemm,
|
||||
**atol_rtol_kw[dtype]
|
||||
), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
|
||||
|
||||
@training_dtypes
|
||||
def test_pack_both_ways_edge_case1(self, dtype) -> None:
|
||||
# In this case, the heuristic will keep 7 values out of 16
|
||||
# instead of 8. let's see how the kernel handles this
|
||||
quad = torch.tensor(
|
||||
[
|
||||
[2, -1, -2, -3], # Should be packed as `2 <null>`
|
||||
[-1, 8, -1, 6],
|
||||
[-1, -1, 4, 5],
|
||||
[-1, 3, 7, -1],
|
||||
],
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
)
|
||||
a = torch.randn([32, 64], dtype=dtype, device="cuda")
|
||||
a[:4, :4] = quad
|
||||
packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
|
||||
# Check first line in A
|
||||
assert packed[0, 0].item() == 2
|
||||
assert packed[0, 1].item() == 0
|
||||
# And first column in A.t
|
||||
assert packed_t[0, 0].item() == 2
|
||||
assert packed_t[0, 1].item() == 0
|
||||
|
||||
@training_dtypes
|
||||
def test_sp24_apply(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
(
|
||||
packed,
|
||||
meta,
|
||||
packed_t,
|
||||
meta_t,
|
||||
bitmask,
|
||||
) = torch._sparse_semi_structured_tile(x)
|
||||
packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
|
||||
assert torch.allclose(packed, packed2)
|
||||
assert torch.allclose(packed_t, packed_t2)
|
||||
|
||||
@training_dtypes
|
||||
def test_sp24_apply_dense(self, dtype) -> None:
|
||||
M, N = 256, 1024
|
||||
x = torch.randn([M, N], dtype=dtype, device="cuda")
|
||||
(
|
||||
packed,
|
||||
meta,
|
||||
packed_t,
|
||||
meta_t,
|
||||
bitmask,
|
||||
) = torch._sparse_semi_structured_tile(x)
|
||||
|
||||
expected = SparseSemiStructuredTensorCUTLASS(
|
||||
x.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
).to_dense()
|
||||
|
||||
packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
|
||||
sparse = SparseSemiStructuredTensorCUTLASS(
|
||||
x.shape,
|
||||
packed=packed2,
|
||||
meta=meta,
|
||||
packed_t=packed_t2,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
|
||||
dense = torch._sparse_semi_structured_apply_dense(x, bitmask)
|
||||
|
||||
assert torch.allclose(dense, expected)
|
||||
assert torch.allclose(sparse.to_dense(), expected)
|
||||
|
||||
|
||||
@training_dtypes
|
||||
def test_sp24_matmuls(self, dtype) -> None:
|
||||
M, N, K = 64, 256, 1024
|
||||
a = torch.randn([M, K], device="cuda", dtype=dtype)
|
||||
b = torch.randn([K, N], device="cuda", dtype=dtype)
|
||||
a_m = sparse24_largest_mask_2d(a)
|
||||
b_m = sparse24_largest_mask_2d(b)
|
||||
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
|
||||
a_s = SparseSemiStructuredTensorCUTLASS(
|
||||
a.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
|
||||
b_s = SparseSemiStructuredTensorCUTLASS(
|
||||
b.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=bitmask,
|
||||
)
|
||||
|
||||
assert torch.allclose(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1e-1)
|
||||
assert torch.allclose(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1e-1)
|
||||
assert torch.allclose(
|
||||
a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1e-1
|
||||
)
|
||||
assert torch.allclose(
|
||||
a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
|
||||
)
|
||||
|
||||
def test_sp24_matmuls_mat_vec(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([128], device="cuda", dtype=torch.float16)
|
||||
a_m = sparse24_largest_mask_2d(a)
|
||||
a_s = to_sparse_semi_structured(a)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
|
||||
|
||||
|
||||
def test_sp24_matmuls_bmm(self) -> None:
|
||||
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
|
||||
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
|
||||
a_m = sparse24_largest_mask_2d(a)
|
||||
a_s = to_sparse_semi_structured(a)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
|
||||
|
||||
class TestSparseSemiStructuredCUTLASS(TestCase):
|
||||
"""
|
||||
This contains CUTLASS specific tests for
|
||||
- torch._sparse_semi_structured_linear
|
||||
"""
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
|
||||
self.skipTest('CUTLASS not enabled')
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
||||
@inference_dtypes
|
||||
def test_linear_cutlass(self, device, dtype):
|
||||
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
|
||||
@parametrize("backend", ["cutlass"])
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_linear_cutlass(self, device, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
|
||||
weight = rand_sparse_semi_structured(m, k, dtype, device)
|
||||
@ -976,8 +643,12 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
|
||||
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
|
||||
@inference_dtypes
|
||||
def test_conversions(self, device, dtype):
|
||||
@parametrize("backend", ["cutlass"])
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_conversions(self, device, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
|
||||
def run_test(r, c, device, dtype):
|
||||
dense_ref = rand_sparse_semi_structured(r, c, dtype, device)
|
||||
@ -1004,8 +675,12 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
|
||||
run_test(r, c, device, dtype)
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
|
||||
@inference_dtypes
|
||||
def test_conversions_all_patterns(self, device, dtype):
|
||||
@parametrize("backend", ["cutlass"])
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_conversions_all_patterns(self, device, dtype, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
||||
if backend == "cutlass" and IS_WINDOWS:
|
||||
self.skipTest("CUTLASS not supported on Windows")
|
||||
r, c = 32, 128
|
||||
|
||||
dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)
|
||||
@ -1015,23 +690,18 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
|
||||
|
||||
torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)
|
||||
|
||||
|
||||
|
||||
CUSPARSELT_NUM_ALG_IDS = 4
|
||||
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
|
||||
|
||||
|
||||
class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
class TestCUSPARSELT(TestCase):
|
||||
"""
|
||||
This contains cuSPARSELt specific tests for
|
||||
torch._cslt_compress
|
||||
torch._cslt_sparse_mm
|
||||
This contains cuSPARSELt specific tests.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
if not _IS_SM8X:
|
||||
self.skipTest('Only runs on SM80')
|
||||
if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
|
||||
self.skipTest('cuSPARSELt not enabled')
|
||||
else:
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = False
|
||||
|
||||
@parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
|
||||
@parametrize("dense_input_shape", [(128, 128)])
|
||||
@ -1045,7 +715,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@training_dtypes
|
||||
@dtypes(torch.float16, torch.bfloat16)
|
||||
def test_cslt_sparse_mm_alpha(self, dtype, device):
|
||||
A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
|
||||
B = torch.ones((256, 128), device=device).to(dtype)
|
||||
@ -1077,7 +747,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
|
||||
@inference_dtypes
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
|
||||
# alg_id=3 not supported for float32 dtype
|
||||
if dtype == torch.float32 and alg_id == 3:
|
||||
@ -1094,7 +764,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
|
||||
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
|
||||
def test_cslt_sparse_mm_search(self, device, dtype):
|
||||
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
@ -1107,10 +777,9 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
|
||||
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestCUSPARSELT, globals(), only_for="cuda")
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -1,22 +1,20 @@
|
||||
import torch
|
||||
|
||||
|
||||
# This is PyTorch implementation of main part of reorder_meta()
|
||||
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
||||
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
||||
# GEMM decides upon layout of this matrix, and at the moment for the
|
||||
# sparse GEMM executed on tensor cores, this is layout described by
|
||||
# ColumnMajorInterleaved<2> data structure, in
|
||||
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
||||
# reordering of meta matrix into meta_reordered matrix calculated
|
||||
# according to these segments of CUTLASS code is re-implemented here.
|
||||
# Note that this calculation produces offsets for scattering metadata
|
||||
# matrix elements into reordered metadata matrix elements (or,
|
||||
# equivalently, for gathering reordered metadata matrix element back
|
||||
# into metadata matrix elements).
|
||||
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
||||
"""
|
||||
This is PyTorch implementation of main part of reorder_meta()
|
||||
function, from tools/util/include/cutlass/util/host_reorder.h file
|
||||
of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
||||
GEMM decides upon layout of this matrix, and at the moment for the
|
||||
sparse GEMM executed on tensor cores, this is layout described by
|
||||
ColumnMajorInterleaved<2> data structure, in
|
||||
include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
||||
reordering of meta matrix into meta_reordered matrix calculated
|
||||
according to these segments of CUTLASS code is re-implemented here.
|
||||
Note that this calculation produces offsets for scattering metadata
|
||||
matrix elements into reordered metadata matrix elements (or,
|
||||
equivalently, for gathering reordered metadata matrix element back
|
||||
into metadata matrix elements).
|
||||
"""
|
||||
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
||||
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
||||
|
||||
@ -43,12 +41,10 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device
|
||||
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
||||
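The offsets returned above serve double duty, as the note says: the same index vector scatters metadata into the reordered layout and gathers it back. A toy demonstration of that duality (the real offsets come from `_calculate_meta_reordering_scatter_offsets` and the ColumnMajorInterleaved<2> layout; a random permutation stands in for them here):

```
import torch

meta = torch.arange(12, dtype=torch.int16)
offsets = torch.randperm(12)                              # stand-in for the computed offsets
reordered = torch.empty_like(meta).scatter_(0, offsets, meta)
assert torch.equal(reordered.gather(0, offsets), meta)    # gather undoes the scatter
```
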
|
||||
|
||||
# This function converts dense matrix into sparse semi-structured
|
||||
# representation, producing "compressed" matrix, in the layout used by
|
||||
# CUTLASS backend, and corresponding metadata matrix.
|
||||
def sparse_semi_structured_from_dense_cutlass(dense):
|
||||
"""
|
||||
This function converts dense matrix into sparse semi-structured
|
||||
representation, producing "compressed" matrix, in the layout used by
|
||||
CUTLASS backend, and corresponding metadata matrix.
|
||||
"""
|
||||
if dense.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
|
||||
@ -176,13 +172,11 @@ def sparse_semi_structured_from_dense_cutlass(dense):
|
||||
return (sparse, meta_reordered.view(m, meta_ncols))
|
||||
|
||||
|
||||
# This function performs reverse of the function above - it
|
||||
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
||||
# in the layout used by CUTLASS backend, and accompanying metadata
|
||||
# matrix.
|
||||
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
|
||||
"""
|
||||
This function performs reverse of the function above - it
|
||||
reconstructs dense matrix from a pair of "compressed" matrix, given
|
||||
in the layout used by CUTLASS backend, and accompanying metadata
|
||||
matrix.
|
||||
"""
|
||||
if sparse.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
|
||||
@ -279,73 +273,3 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
|
||||
)
|
||||
|
||||
return dense.view(m, 2 * k)
|
||||
|
||||
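Since both converters appear in this hunk, a hedged round-trip sketch may help: pack an already 2:4-sparse matrix with `sparse_semi_structured_from_dense_cutlass` and recover it with `sparse_semi_structured_to_dense_cutlass`. This assumes a shape that satisfies the divisibility checks in the functions above (rows divisible by 32 and enough columns for the fp16 metadata layout):

```
import torch
from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)

dense = torch.tensor([0, 0, 1, 2]).tile((32, 32)).half()   # (32, 128), two nonzeros per group of 4
packed, meta = sparse_semi_structured_from_dense_cutlass(dense)
assert torch.equal(sparse_semi_structured_to_dense_cutlass(packed, meta), dense)
```
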

def _sparse_semi_structured_tile(dense):
    """
    This function computes a 2:4 sparse tile by greedily taking the largest values.

    Since we take the largest values greedily, how the sorting algorithm handles duplicates affects
    the ultimate sparsity pattern.

    Note that this function does not have the same sorting semantics as our CUDA backend,
    which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern.
    """

    def greedy_prune_tile(tile):
        num_kept_row = [0, 0, 0, 0]
        num_kept_col = [0, 0, 0, 0]

        for x in tile.flatten().sort(descending=True, stable=True).indices:
            r, c = x // 4, x % 4
            if num_kept_row[r] < 2 and num_kept_col[c] < 2:
                num_kept_row[r] += 1
                num_kept_col[c] += 1
            else:
                tile[r, c] = 0

    for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
        for tile in batch:
            greedy_prune_tile(tile)

    return dense


def _compute_compressed_swizzled_bitmask(dense):
    """
    Calculates the compressed swizzled bitmask from a dense tensor
    """

    # first we need to convert the dense tensor to a bitmask
    int_bitmask = dense.bool().to(torch.uint8)

    # Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
    # A, B, C and D, as displayed in the following schema:
    # +---+---+
    # | A | B |
    # +---+---+
    # | C | D |
    # +---+---+

    # we first need to split into the 8x8 tiles
    bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)

    # then we unfold again to get our individual 4x4 tiles
    bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)

    # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern
    # of that tile. Note that the least significant bit is stored first.
    # [1 1 0 0]
    # [1 1 0 0]  ->  0011 0011  ->   51
    # [0 0 1 1]      1100 1100      204
    # [0 0 1 1]

    # reshape tensor to expand tiles into 8-bit vectors
    bitmask_binary_representation = bitmask_4x4_chunks.reshape(*bitmask_4x4_chunks.shape[:2], 4, 2, 8)

    # to convert from binary representation, we can do a matmul with powers of two
    powers_of_two = 2**torch.arange(8, dtype=torch.float, device="cuda")
    # To run on GPU: cast to float to do matmul and then cast back
    compressed_swizzled_bitmask = (bitmask_binary_representation.to(torch.float) @ powers_of_two).to(torch.uint8)

    return compressed_swizzled_bitmask

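A quick, CPU-only sanity check of the byte encoding described in the comments above (least significant bit stored first, two bytes per 4x4 tile); the swizzled per-thread layout itself is not reproduced here:

```
import torch

pattern = torch.tensor([[1, 1, 0, 0],
                        [1, 1, 0, 0],
                        [0, 0, 1, 1],
                        [0, 0, 1, 1]], dtype=torch.int64)
bits = pattern.reshape(2, 8)        # rows 0-1 form byte 0, rows 2-3 form byte 1
weights = 2 ** torch.arange(8)      # LSB first
assert (bits * weights).sum(-1).tolist() == [51, 204]
```
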
@ -70,8 +70,8 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
        meta=self.meta_t,
        packed_t=self.packed,
        meta_t=self.meta,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
        if self.compressed_swizzled_bitmask is not None
        threads_masks=self.threads_masks.transpose(0, 1)
        if self.threads_masks is not None
        else None,
        fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
        alg_id_cusparselt=args[0].alg_id_cusparselt,
@ -97,7 +97,7 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        threads_masks=self.threads_masks,
        requires_grad=False,
    )

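The transpose handler above is why the subclass keeps both `packed` and `packed_t`: transposing never re-packs anything, it just swaps the two pre-packed representations (and transposes the per-thread mask). A toy sketch of that design choice, illustrative only and not the real subclass:

```
from dataclasses import dataclass
from typing import Any

@dataclass
class PackedPair:
    packed: Any      # compressed representation of A
    packed_t: Any    # compressed representation of A.t()

    def t(self) -> "PackedPair":
        # transposition is metadata-only: swap the two packings, no re-packing
        return PackedPair(packed=self.packed_t, packed_t=self.packed)
```
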
@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, List, Callable, Dict
|
||||
import torch
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_from_dense_cutlass,
|
||||
sparse_semi_structured_to_dense_cutlass
|
||||
sparse_semi_structured_to_dense_cutlass,
|
||||
)
|
||||
from torch.sparse._semi_structured_ops import (
|
||||
fallback_dispatcher,
|
||||
@ -56,18 +56,17 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
_FUSE_TRANSPOSE: bool = False
|
||||
_PROTOTYPE_WARNING_SHOWN: bool = False
|
||||
|
||||
BACKEND: str
|
||||
SPARSE_DISPATCH: Dict[Callable, Callable]
|
||||
|
||||
packed: Optional[torch.Tensor]
|
||||
meta: Optional[torch.Tensor]
|
||||
packed_t: Optional[torch.Tensor]
|
||||
meta_t: Optional[torch.Tensor]
|
||||
compressed_swizzled_bitmask: Optional[torch.Tensor]
|
||||
threads_masks: Optional[torch.Tensor]
|
||||
fuse_transpose_cusparselt: bool
|
||||
alg_id_cusparselt: int
|
||||
|
||||
__slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
|
||||
__slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
|
||||
|
||||
@staticmethod
|
||||
def __new__( # noqa: PYI034
|
||||
@ -77,7 +76,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta: Optional[torch.Tensor],
|
||||
packed_t: Optional[torch.Tensor],
|
||||
meta_t: Optional[torch.Tensor],
|
||||
compressed_swizzled_bitmask: Optional[torch.Tensor],
|
||||
threads_masks: Optional[torch.Tensor],
|
||||
fuse_transpose_cusparselt: bool = False,
|
||||
alg_id_cusparselt: int = 0,
|
||||
requires_grad: bool = False,
|
||||
@ -96,8 +95,8 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta: The metadata of the original dense tensor, if it is stored separately
|
||||
packed_t: The compressed representation of the transposed original dense tensor
|
||||
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
|
||||
compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
|
||||
participate in the computation. Used for pointwise ops.
|
||||
threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
|
||||
Used for pointwise ops.
|
||||
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
|
||||
with a matmul, which is useful in the case of 2:4 sparse training.
|
||||
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
|
||||
@ -125,9 +124,6 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
|
||||
cls._load_dispatch_table()
|
||||
|
||||
# we can also register the classes with dynamo when the warning is shown.
|
||||
torch._dynamo.allow_in_graph(cls)
|
||||
|
||||
if packed is not None:
|
||||
previous_tensor = packed
|
||||
elif packed_t is not None:
|
||||
@ -147,7 +143,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
tensor.meta = meta
|
||||
tensor.packed_t = packed_t
|
||||
tensor.meta_t = meta_t
|
||||
tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
|
||||
tensor.threads_masks = threads_masks
|
||||
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
|
||||
tensor.alg_id_cusparselt = alg_id_cusparselt
|
||||
return tensor
|
||||
@ -185,7 +181,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta=inner_tensors.get("meta", None),
|
||||
packed_t=inner_tensors.get("packed_t", None),
|
||||
meta_t=inner_tensors.get("meta_t", None),
|
||||
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
|
||||
threads_masks=inner_tensors.get("threads_masks", None),
|
||||
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
|
||||
alg_id_cusparselt=alg_id_cusparselt,
|
||||
requires_grad=requires_grad,
|
||||
@ -220,7 +216,6 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
torch.ops.aten.matmul: semi_sparse_mm,
|
||||
torch.ops.aten.addmm: semi_sparse_addmm,
|
||||
torch.ops.aten.linear: semi_sparse_linear,
|
||||
torch.ops.aten._to_copy: fallback_dispatcher,
|
||||
}
|
||||
if custom_dispatch_table is not None:
|
||||
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
|
||||
@ -364,14 +359,13 @@ def to_sparse_semi_structured(
|
||||
"SparseSemiStructuredTensor only support contiguous input tensors. "
|
||||
)
|
||||
|
||||
# set from _FORCE_CUTLASS flag
|
||||
SPARSE_SUBCLASS = (
|
||||
sparse_subclass = (
|
||||
torch.sparse.SparseSemiStructuredTensorCUTLASS
|
||||
if SparseSemiStructuredTensor._FORCE_CUTLASS
|
||||
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
|
||||
)
|
||||
return sparse_subclass.from_dense(original_tensor)
|
||||
|
||||
return SPARSE_SUBCLASS.from_dense(original_tensor)
|
||||
|
||||
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
"""
|
||||
@ -384,7 +378,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
|
||||
sparse_semi_structured_from_dense for conversion to the compressed format.
|
||||
"""
|
||||
BACKEND = "cutlass"
|
||||
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
|
||||
@ -407,71 +401,19 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
meta=meta_tensor_cutlass,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
compressed_swizzled_bitmask=None,
|
||||
threads_masks=None,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
def to_dense(self):
|
||||
assert self.meta is not None and self.packed is not None
|
||||
return sparse_semi_structured_to_dense_cutlass(
|
||||
self.packed,
|
||||
self.meta,
|
||||
) if self.meta.ndim == 2 else super().to_dense()
|
||||
|
||||
    @classmethod
    def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
        """
        This function takes in an unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.

        It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
        The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.

        Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
        It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
        pruned dense tensor.
        Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.

        Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern.
        This can be used in the backward pass to mask the gradients.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4]  -> prune 4x4 tile -> [8 0 0 4]  -> pack to CUTLASS semi-structured  -> packed
        [1 2 6 2]                       [0 0 6 2]                                      -> metadata

                                                   -> pack to transposed CUTLASS       -> packed_t
                                                      semi-structured representation   -> metadata_t

                                                   -> compute swizzled bitmask         -> compressed_swizzled_bitmask


        The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUTLASS
        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask

        pruned = _sparse_semi_structured_tile(dense)
        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
        ```
        """
        # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
        (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
            original_tensor,
            algorithm=algorithm,
            use_cutlass=True)

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        return (
            sparse_semi_structured_to_dense_cutlass(
                self.packed,
                self.meta,
            )
            if self.meta.ndim == 2
            else super().to_dense()
        )

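For context, a hypothetical end-to-end use of the classmethod above; it only works on a build that still contains the kernels this commit reverts, on an SM8x GPU:

```
import torch
from torch.sparse import SparseSemiStructuredTensorCUTLASS

dense = torch.randn(128, 128, device="cuda", dtype=torch.float16)
sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(dense)
pruned_dense = sparse.to_dense()       # the 2:4-pruned version of `dense`
out = sparse @ torch.randn(128, 64, device="cuda", dtype=torch.float16)
```
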
def _mm(
|
||||
@ -517,7 +459,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
|
||||
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
|
||||
"""
|
||||
BACKEND = "cusparselt"
|
||||
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
|
||||
@ -534,59 +476,12 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||
meta=None,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
compressed_swizzled_bitmask=None,
|
||||
threads_masks=None,
|
||||
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
|
||||
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
    @classmethod
    def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
        """
        This function does the same thing as described in SparseSemiStructuredTensorCUTLASS, but uses the cuSPARSELt metadata
        layout and sparse matmul.

        The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4]  -> prune 4x4 tile -> [8 0 0 4]  -> pack to cuSPARSELt semi-structured -> packed
        [1 2 6 2]                       [0 0 6 2]

                                                   -> pack to transposed cuSPARSELt      -> packed_t
                                                      semi-structured representation

                                                   -> compute swizzled bitmask           -> compressed_swizzled_bitmask


        The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask

        pruned = _sparse_semi_structured_tile(dense)
        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cusparselt, None, packed_t_cusparselt, None, bitmask)
        ```
        """
        (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
            original_tensor,
            algorithm=algorithm,
            use_cutlass=False)

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        )

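And the analogous sketch for the cuSPARSELt path, using the private ops exercised by the tests in this diff (`torch._cslt_compress`, `torch._cslt_sparse_mm`); this assumes a cuSPARSELt-enabled CUDA build:

```
import torch

A = torch.tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()   # (128, 128), already 2:4 along rows
B = torch.rand(128, 128, device="cuda", dtype=torch.float16)
A_compressed = torch._cslt_compress(A)                          # values and metadata packed into one tensor
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t())      # sparse-dense matmul, as in the tests
```
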
def _mm(
|
||||
self,
|
||||
B: torch.Tensor,
|
||||