mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283. Almost there! Test plan: dmypy restart && python3 scripts/lintrunner.py -a; pyrefly check. Step 1: delete lines in the pyrefly.toml file from the project-excludes field. Step 2: run pyrefly check. Step 3: add suppressions and clean up unused suppressions. Before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 After: INFO 0 errors (5,064 ignored). Only four directories are left to enable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877 Approved by: https://github.com/oulgen
359 lines
14 KiB
Python
359 lines
14 KiB
Python
# mypy: allow-untyped-defs
|
|
import torch
|
|
|
|
|
|
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
|
"""
|
|
This is PyTorch implementation of main part of reorder_meta()
|
|
function, from tools/util/include/cutlass/util/host_reorder.h file
|
|
of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
|
GEMM decides upon layout of this matrix, and at the moment for the
|
|
sparse GEMM executed on tensor cores, this is layout described by
|
|
ColumnMajorInterleaved<2> data structure, in
|
|
include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
|
reordering of meta matrix into meta_reordered matrix calculated
|
|
according to these segments of CUTLASS code is re-implemented here.
|
|
Note that this calculation produces offsets for scattering metadata
|
|
matrix elements into reordered metadata matrix elements (or,
|
|
equivalently, for gathering reordered metadata matrix element back
|
|
into metadata matrix elements).
|
|
"""
|
|
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
|
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
|
|
|
# Reorder the rows, then swizzle the 2x2 blocks.
|
|
group = 32 if meta_dtype.itemsize == 2 else 16
|
|
interweave = 4 if meta_dtype.itemsize == 2 else 2
|
|
dst_rows = (
|
|
dst_rows // group * group
|
|
+ (dst_rows % 8) * interweave
|
|
+ (dst_rows % group) // 8
|
|
)
|
|
|
|
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
|
|
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
|
|
dst_rows += topright - bottomleft
|
|
dst_cols -= topright - bottomleft
|
|
|
|
# Assumed that meta tensor is to be stored in CUTLASS
|
|
# InterleavedColumnMajor layout, and reverse engineered
|
|
# corresponding code to store values into this tensor.
|
|
interleave = 2
|
|
cols_maj = dst_cols // interleave
|
|
cols_min = dst_cols % interleave
|
|
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
|
|
|
|
|
def sparse_semi_structured_from_dense_cutlass(dense):
|
|
"""
|
|
This function converts dense matrix into sparse semi-structured
|
|
representation, producing "compressed" matrix, in the layout used by
|
|
CUTLASS backend, and corresponding metadata matrix.
|
|
"""
|
|
if dense.dim() != 2:
|
|
raise RuntimeError(
|
|
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
|
|
)
|
|
|
|
m, k = dense.shape
|
|
device = dense.device
|
|
|
|
meta_dtype = torch.int8
|
|
if dense.dtype == torch.int8:
|
|
meta_dtype = torch.int32
|
|
elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
meta_dtype = torch.int16
|
|
else:
|
|
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
|
|
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
|
|
if quadbits_per_meta_elem not in (4, 8):
|
|
raise RuntimeError("Invalid number of elements per meta element calculated")
|
|
|
|
if meta_dtype == torch.int32:
|
|
if m % 16 != 0:
|
|
raise RuntimeError(
|
|
f"Number of rows of dense matrix {m} must be divisible by 16"
|
|
)
|
|
else:
|
|
if m % 32 != 0:
|
|
raise RuntimeError(
|
|
f"Number of rows of dense matrix {m} must be divisible by 32"
|
|
)
|
|
if k % (4 * quadbits_per_meta_elem) != 0:
|
|
raise RuntimeError(
|
|
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
|
|
)
|
|
|
|
if dense.dtype != torch.float:
|
|
ksparse = 4
|
|
dense_4 = dense.view(-1, k // ksparse, ksparse)
|
|
m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
|
|
else:
|
|
ksparse = 2
|
|
dense_2 = dense.view(-1, k // ksparse, ksparse)
|
|
m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
|
|
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
|
|
|
|
# Encoding quadruples of True/False values as follows:
|
|
# [True, True, False, False] -> 0b0100
|
|
# [True, False, True, False] -> 0b1000
|
|
# [False, True, True, False] -> 0b1001
|
|
# [True, False, False, True ] -> 0b1100
|
|
# [False, True, False, True ] -> 0b1101
|
|
# [False, False, True, True ] -> 0b1110
|
|
# Thus, lower two bits in the encoding are index of the True value
|
|
# at the lowest index in the quadruple, and the higher two bits in
|
|
# the encoding are index of the other True value in the quadruple.
|
|
# In case there are less than two True values, than False value or
|
|
# values at some index or indices are considered True for the
|
|
# encoding. In case there are more than two True values, then the
|
|
# excess True value(s) at some indices are considered False for
|
|
# the encoding. The exact encodings used for these cases are as
|
|
# follows:
|
|
# [False, False, False, False] -> 0b1110
|
|
# [False, False, False, True ] -> 0b1110
|
|
# [False, False, True, False] -> 0b1110
|
|
# [False, True, False, False] -> 0b1001
|
|
# [False, True, True, True ] -> 0b1101
|
|
# [True, False, False, False] -> 0b1000
|
|
# [True, False, True, True ] -> 0b1100
|
|
# [True, True, False, True ] -> 0b0100
|
|
# [True, True, True, False] -> 0b0100
|
|
# [True, True, True, True ] -> 0b0100
|
|
# These particular encodings are chosen, with the help of Espresso
|
|
# logic minimizer software, for the purpose of minimization of
|
|
# corresponding Boolean functions, that translate non-zero flags
|
|
# into encoding bits. Note also possible choices for the first
|
|
# and last of these encodings were limited only to (0b0100,
|
|
# 0b1110), in order to produce valid encodings for 1:2 sparsity
|
|
# case.
|
|
|
|
expr0 = m0 & m1
|
|
expr1 = ~m0 & m1
|
|
expr2 = ~m0 & ~m1
|
|
bit0 = expr1
|
|
bit1 = expr2
|
|
bit2 = expr0 | expr2 | m3
|
|
bit3 = expr1 | ~m1
|
|
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
|
|
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
|
|
|
|
if dense.dtype != torch.float:
|
|
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
|
|
# pyrefly: ignore # unbound-name
|
|
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
|
|
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
|
|
else:
|
|
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
|
|
|
|
meta_4 = idxs0 | (idxs1 << 2)
|
|
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
|
|
|
|
if quadbits_per_meta_elem == 4:
|
|
meta = (
|
|
meta_n[:, :, 0]
|
|
| (meta_n[:, :, 1] << 4)
|
|
| (meta_n[:, :, 2] << 8)
|
|
| (meta_n[:, :, 3] << 12)
|
|
)
|
|
elif quadbits_per_meta_elem == 8:
|
|
meta = (
|
|
meta_n[:, :, 0]
|
|
| (meta_n[:, :, 1] << 4)
|
|
| (meta_n[:, :, 2] << 8)
|
|
| (meta_n[:, :, 3] << 12)
|
|
| (meta_n[:, :, 4] << 16)
|
|
| (meta_n[:, :, 5] << 20)
|
|
| (meta_n[:, :, 6] << 24)
|
|
| (meta_n[:, :, 7] << 28)
|
|
)
|
|
|
|
# Reorder meta tensor elements.
|
|
meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined]
|
|
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
|
m, meta_ncols, meta_dtype, device
|
|
)
|
|
# pyrefly: ignore # unbound-name
|
|
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
|
|
|
|
return (sparse, meta_reordered.view(m, meta_ncols))
|
|
|
|
|
|
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
|
|
"""
|
|
This function performs reverse of the function above - it
|
|
reconstructs dense matrix from a pair of "compressed" matrix, given
|
|
in the layout used by CUTLASS backend, and accompanying metadata
|
|
matrix.
|
|
"""
|
|
if sparse.dim() != 2:
|
|
raise RuntimeError(
|
|
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
|
|
)
|
|
|
|
m, k = sparse.shape
|
|
device = sparse.device
|
|
|
|
if meta_reordered.dim() != 2:
|
|
raise RuntimeError(
|
|
f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
|
|
)
|
|
if meta_reordered.device != device:
|
|
raise RuntimeError(
|
|
f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
|
|
)
|
|
|
|
meta_dtype = meta_reordered.dtype
|
|
if meta_dtype not in (torch.int16, torch.int32):
|
|
raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
|
|
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
|
|
|
|
if sparse.dtype != torch.float:
|
|
ksparse = 4
|
|
else:
|
|
ksparse = 2
|
|
|
|
meta_nrows, meta_ncols = meta_reordered.shape
|
|
if meta_nrows != m:
|
|
raise RuntimeError(
|
|
f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"
|
|
)
|
|
if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
|
|
raise RuntimeError(
|
|
f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "
|
|
"expected according to the number of columns of meta matrix"
|
|
)
|
|
|
|
# Undo meta tensor elements reordering.
|
|
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
|
m, meta_ncols, meta_dtype, device
|
|
)
|
|
meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
|
|
|
|
# Unpack sparse tensor back to original dense tensor, using
|
|
# information provided by meta tensor. Note that torch.float
|
|
# datatype is handled pretty much the same as
|
|
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
|
|
# value is encoded as if underlying 8 bytes contain four
|
|
# torch.half/torch.bfloat16 values, where either first two or last
|
|
# two are zeros.
|
|
meta_2 = torch.empty(
|
|
(m, meta_ncols, 2 * quadbits_per_meta_elem),
|
|
dtype=meta_dtype,
|
|
device=device,
|
|
)
|
|
if quadbits_per_meta_elem == 4:
|
|
meta_2[:, :, 0] = meta & 0b11
|
|
meta_2[:, :, 1] = (meta >> 2) & 0b11
|
|
meta_2[:, :, 2] = (meta >> 4) & 0b11
|
|
meta_2[:, :, 3] = (meta >> 6) & 0b11
|
|
meta_2[:, :, 4] = (meta >> 8) & 0b11
|
|
meta_2[:, :, 5] = (meta >> 10) & 0b11
|
|
meta_2[:, :, 6] = (meta >> 12) & 0b11
|
|
meta_2[:, :, 7] = (meta >> 14) & 0b11
|
|
elif quadbits_per_meta_elem == 8:
|
|
meta_2[:, :, 0] = meta & 0b11
|
|
meta_2[:, :, 1] = (meta >> 2) & 0b11
|
|
meta_2[:, :, 2] = (meta >> 4) & 0b11
|
|
meta_2[:, :, 3] = (meta >> 6) & 0b11
|
|
meta_2[:, :, 4] = (meta >> 8) & 0b11
|
|
meta_2[:, :, 5] = (meta >> 10) & 0b11
|
|
meta_2[:, :, 6] = (meta >> 12) & 0b11
|
|
meta_2[:, :, 7] = (meta >> 14) & 0b11
|
|
meta_2[:, :, 8] = (meta >> 16) & 0b11
|
|
meta_2[:, :, 9] = (meta >> 18) & 0b11
|
|
meta_2[:, :, 10] = (meta >> 20) & 0b11
|
|
meta_2[:, :, 11] = (meta >> 22) & 0b11
|
|
meta_2[:, :, 12] = (meta >> 24) & 0b11
|
|
meta_2[:, :, 13] = (meta >> 26) & 0b11
|
|
meta_2[:, :, 14] = (meta >> 28) & 0b11
|
|
meta_2[:, :, 15] = (meta >> 30) & 0b11
|
|
|
|
dense_offsets = meta_2.view(-1) + (
|
|
torch.arange(0, 2 * m * k // ksparse, device=device) * 4
|
|
).view(-1, 1).repeat(1, 2).view(-1)
|
|
|
|
dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
|
|
if sparse.dtype != torch.float:
|
|
dense.scatter_(0, dense_offsets, sparse.view(-1))
|
|
else:
|
|
dense.view(torch.half).scatter_(
|
|
0, dense_offsets, sparse.view(torch.half).view(-1)
|
|
)
|
|
|
|
return dense.view(m, 2 * k)
|
|
|
|
|
|
def _sparse_semi_structured_tile(dense):
|
|
"""
|
|
This function computes a 2:4 sparse tile by greedily taking the largest values.
|
|
|
|
Since we take the largest values greedily, how the sorting algorithm handles duplicates affects
|
|
the ultimate sparsity pattern.
|
|
|
|
Note that this function does not have the same sorting semantics as our CUDA backend,
|
|
which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern.
|
|
"""
|
|
|
|
def greedy_prune_tile(tile):
|
|
num_kept_row = [0, 0, 0, 0]
|
|
num_kept_col = [0, 0, 0, 0]
|
|
|
|
for x in tile.flatten().sort(descending=True, stable=True).indices:
|
|
r, c = x // 4, x % 4
|
|
if num_kept_row[r] < 2 and num_kept_col[c] < 2:
|
|
num_kept_row[r] += 1
|
|
num_kept_col[c] += 1
|
|
else:
|
|
tile[r, c] = 0
|
|
|
|
for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
|
|
for tile in batch:
|
|
greedy_prune_tile(tile)
|
|
|
|
return dense
|
|
|
|
|
|
def _compute_compressed_swizzled_bitmask(dense):
|
|
"""
|
|
Calculates the compressed swizzled bitmask from a dense tensor
|
|
"""
|
|
|
|
# first we need to convert the dense tensor to a bitmask
|
|
int_bitmask = dense.bool().to(torch.uint8)
|
|
|
|
# Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
|
|
# A, B, C and D, as displayed in the following schema:
|
|
# +---+---+
|
|
# | A | B |
|
|
# +---+---+
|
|
# | C | D |
|
|
# +---+---+
|
|
|
|
# we first need to split into the 8x8 tiles
|
|
bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)
|
|
|
|
# then we unfold again to get our individual 4x4 tiles
|
|
bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)
|
|
|
|
# Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern
|
|
# of that tile. Note that the least significant bit is stored first.
|
|
# [1 1 0 0]
|
|
# [1 1 0 0] -> 0011 0011 -> 51
|
|
# [0 0 1 1] 1100 1100 204
|
|
# [0 0 1 1]
|
|
|
|
# reshape tensor to expand tiles into 8-bit vectors
|
|
bitmask_binary_representation = bitmask_4x4_chunks.reshape(
|
|
*bitmask_4x4_chunks.shape[:2], 4, 2, 8
|
|
)
|
|
|
|
# to convert from binary representation, we can do a matmul with powers of two
|
|
powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
|
|
# To run on GPU: cast to float to do matmul and then cast back
|
|
compressed_swizzled_bitmask = (
|
|
bitmask_binary_representation.to(torch.float) @ powers_of_two
|
|
).to(torch.uint8)
|
|
|
|
return compressed_swizzled_bitmask
|