From 535016967ae65a6027f83d6b935a985996223d49 Mon Sep 17 00:00:00 2001 From: WeiChunyu-star Date: Mon, 15 Jul 2024 22:35:52 +0000 Subject: [PATCH] Enable UFMT on all of torch/sparse (#130545) Partially addresses #123062 Ran lintrunner on: - torch/sparse Detail: ``` $ lintrunner -a --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/130545 Approved by: https://github.com/ezyang --- .lintrunner.toml | 4 - torch/sparse/__init__.py | 190 ++-- torch/sparse/_semi_structured_conversions.py | 10 +- torch/sparse/_triton_ops.py | 961 +++++++++++++------ torch/sparse/semi_structured.py | 95 +- 5 files changed, 884 insertions(+), 376 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 355b0a925e95..67c1fd0d96c2 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1531,10 +1531,6 @@ exclude_patterns = [ 'torch/signal/__init__.py', 'torch/signal/windows/__init__.py', 'torch/signal/windows/windows.py', - 'torch/sparse/__init__.py', - 'torch/sparse/_semi_structured_conversions.py', - 'torch/sparse/_triton_ops.py', - 'torch/sparse/semi_structured.py', 'torch/special/__init__.py', 'torch/testing/_internal/__init__.py', 'torch/testing/_internal/autocast_test_lists.py', diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 5b86e068096f..8b3a1f2e2b2d 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -1,23 +1,23 @@ # mypy: allow-untyped-defs # The Tensor classes are added to this module by python_tensor.cpp -from typing import Optional, Tuple, List, Union, Any +# A workaround to support both TorchScript and MyPy: +from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union import torch -from torch._C import _add_docstr, _sparse # type: ignore[attr-defined] from torch import Tensor +from torch._C import _add_docstr, _sparse # type: ignore[attr-defined] # Semi structured sparsity support from .semi_structured import ( SparseSemiStructuredTensor, SparseSemiStructuredTensorCUSPARSELT, SparseSemiStructuredTensorCUTLASS, - to_sparse_semi_structured + to_sparse_semi_structured, ) -# A workaround to support both TorchScript and MyPy: -from typing import TYPE_CHECKING if TYPE_CHECKING: from torch.types import _dtype as DType + DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]] else: # The JIT doesn't understand Union, nor torch.dtype here @@ -26,20 +26,22 @@ else: __all__ = [ - 'addmm', - 'check_sparse_tensor_invariants', - 'mm', - 'sum', - 'softmax', - 'log_softmax', - 'SparseSemiStructuredTensor', - 'SparseSemiStructuredTensorCUTLASS', - 'SparseSemiStructuredTensorCUSPARSELT', - 'to_sparse_semi_structured', - 'as_sparse_gradcheck', + "addmm", + "check_sparse_tensor_invariants", + "mm", + "sum", + "softmax", + "log_softmax", + "SparseSemiStructuredTensor", + "SparseSemiStructuredTensorCUTLASS", + "SparseSemiStructuredTensorCUSPARSELT", + "to_sparse_semi_structured", + "as_sparse_gradcheck", ] -addmm = _add_docstr(_sparse._sparse_addmm, r""" +addmm = _add_docstr( + _sparse._sparse_addmm, + r""" sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) 
-> Tensor This function does exact same thing as :func:`torch.addmm` in the forward, @@ -58,10 +60,13 @@ Args: mat2 (Tensor): a dense matrix to be multiplied beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) -""") +""", +) -mm = _add_docstr(_sparse._sparse_mm, r""" +mm = _add_docstr( + _sparse._sparse_mm, + r""" Performs a matrix multiplication of the sparse matrix :attr:`mat1` and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a @@ -132,10 +137,13 @@ Example:: >>> y2 tensor([[0., 1.], [6., 0.]], grad_fn=) -""") +""", +) -sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r""" +sampled_addmm = _add_docstr( + _sparse.sparse_sampled_addmm, + r""" sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations @@ -184,10 +192,11 @@ Examples:: col_indices=tensor([0, 1, 2]), values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0', size=(3, 3), nnz=3, layout=torch.sparse_csr) -""") +""", +) -def sum(input: Tensor, dim: DimOrDims = None, - dtype: Optional[DType] = None) -> Tensor: + +def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor: r"""Return the sum of each row of the given sparse tensor. Returns the sum of each row of the sparse tensor :attr:`input` in the given @@ -256,7 +265,9 @@ def sum(input: Tensor, dim: DimOrDims = None, return torch._sparse_sum(input, dtype=dtype) -softmax = _add_docstr(_sparse._sparse_softmax, r""" +softmax = _add_docstr( + _sparse._sparse_softmax, + r""" sparse.softmax(input, dim, *, dtype=None) -> Tensor Applies a softmax function. @@ -281,10 +292,13 @@ Args: casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None -""") +""", +) -log_softmax = _add_docstr(_sparse._sparse_log_softmax, r""" +log_softmax = _add_docstr( + _sparse._sparse_log_softmax, + r""" sparse.log_softmax(input, dim, *, dtype=None) -> Tensor Applies a softmax function followed by logarithm. @@ -299,7 +313,8 @@ Args: casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None -""") +""", +) spdiags = _add_docstr( @@ -393,7 +408,8 @@ Specifying a positive offset:: [0, 0, 3, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) -""") +""", +) class check_sparse_tensor_invariants: @@ -483,12 +499,14 @@ class check_sparse_tensor_invariants: # context manager support def __init__(self, enable=True): self.state = enable - self.saved_state : Optional[bool] = None + self.saved_state: Optional[bool] = None def __enter__(self): if self.saved_state is not None: - raise RuntimeError('This context manager instance is already activated.' - ' Use a different context manager instance for context nesting.') + raise RuntimeError( + "This context manager instance is already activated." + " Use a different context manager instance for context nesting." 
+ ) self.saved_state = self.is_enabled() torch._C._set_check_sparse_tensor_invariants(self.state) @@ -499,7 +517,6 @@ class check_sparse_tensor_invariants: # decorator support def __call__(self, mth): - def test_mth(*args, **kwargs): with type(self)(self.state): return mth(*args, **kwargs) @@ -531,37 +548,71 @@ def as_sparse_gradcheck(gradcheck): Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support. """ - masked = kwargs.pop('masked', False) - sparse_layouts = {torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} - sparse_compressed_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} + masked = kwargs.pop("masked", False) + sparse_layouts = { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + sparse_compressed_layouts = { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc} - STRIDED_REPRESENTATION = '__STRIDED_REPRESENTATION__' + STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__" def convert_to_strided_representation(args): """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors.""" if not isinstance(args, (list, tuple)): - args = args, + args = (args,) new_args: List[Any] = [] for obj in args: - if isinstance(obj, torch.Tensor) and obj.requires_grad and obj.layout in sparse_layouts: + if ( + isinstance(obj, torch.Tensor) + and obj.requires_grad + and obj.layout in sparse_layouts + ): d = dict(layout=obj.layout, shape=obj.shape) if not masked: # Materialize unspecified elements with zero values batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim() - blocksize = obj.values().shape[batch_dim + 1:batch_dim + 3] if obj.layout in sparse_block_layouts else None - full_mask = torch.ones(obj.shape, device=obj.device, dtype=torch.bool).to_sparse( - layout=obj.layout, blocksize=blocksize, dense_dim=obj.dense_dim()) + blocksize = ( + obj.values().shape[batch_dim + 1 : batch_dim + 3] + if obj.layout in sparse_block_layouts + else None + ) + full_mask = torch.ones( + obj.shape, device=obj.device, dtype=torch.bool + ).to_sparse( + layout=obj.layout, + blocksize=blocksize, + dense_dim=obj.dense_dim(), + ) obj = obj.to_dense().sparse_mask(full_mask) if obj.layout is torch.sparse_coo: - d.update(indices=obj._indices(), is_coalesced=obj.is_coalesced()) + d.update( + indices=obj._indices(), is_coalesced=obj.is_coalesced() + ) values = obj._values() elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: - d.update(compressed_indices=obj.crow_indices(), plain_indices=obj.col_indices()) + d.update( + compressed_indices=obj.crow_indices(), + plain_indices=obj.col_indices(), + ) values = obj.values() else: - d.update(compressed_indices=obj.ccol_indices(), plain_indices=obj.row_indices()) + d.update( + compressed_indices=obj.ccol_indices(), + plain_indices=obj.row_indices(), + ) values = obj.values() - new_args.extend((STRIDED_REPRESENTATION, d, values.requires_grad_(True))) + new_args.extend( + (STRIDED_REPRESENTATION, d, values.requires_grad_(True)) + ) else: new_args.append(obj) return tuple(new_args) @@ -574,13 +625,25 @@ def as_sparse_gradcheck(gradcheck): a = args.pop(0) if a == STRIDED_REPRESENTATION: d, values = args.pop(0), args.pop(0) - if d['layout'] is torch.sparse_coo: - a = torch.sparse_coo_tensor(d['indices'], values, size=d['shape'], is_coalesced=d['is_coalesced']) - elif d['layout'] in 
sparse_compressed_layouts: - a = torch.sparse_compressed_tensor(d['compressed_indices'], d['plain_indices'], values, - size=d['shape'], layout=d['layout']) + if d["layout"] is torch.sparse_coo: + a = torch.sparse_coo_tensor( + d["indices"], + values, + size=d["shape"], + is_coalesced=d["is_coalesced"], + ) + elif d["layout"] in sparse_compressed_layouts: + a = torch.sparse_compressed_tensor( + d["compressed_indices"], + d["plain_indices"], + values, + size=d["shape"], + layout=d["layout"], + ) else: - raise NotImplementedError(f'conversion of {d["layout"]} strided representation to tensor') + raise NotImplementedError( + f'conversion of {d["layout"]} strided representation to tensor' + ) new_args.append(a) return tuple(new_args) @@ -591,12 +654,25 @@ def as_sparse_gradcheck(gradcheck): # tensors: outputs = func(*restored_args, **kwargs) - strided_outputs = tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,) - strided_outputs = tuple((o.to_dense(masked_grad=masked) - if isinstance(o, torch.Tensor) and o.requires_grad and o.layout in sparse_layouts else o) - for o in strided_outputs) + strided_outputs = ( + tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,) + ) + strided_outputs = tuple( + ( + o.to_dense(masked_grad=masked) + if isinstance(o, torch.Tensor) + and o.requires_grad + and o.layout in sparse_layouts + else o + ) + for o in strided_outputs + ) - return strided_outputs if isinstance(outputs, (list, tuple)) else strided_outputs[0] + return ( + strided_outputs + if isinstance(outputs, (list, tuple)) + else strided_outputs[0] + ) args = (func_wrapper, convert_to_strided_representation(inputs)) diff --git a/torch/sparse/_semi_structured_conversions.py b/torch/sparse/_semi_structured_conversions.py index 141464f7dc76..0828355202b5 100644 --- a/torch/sparse/_semi_structured_conversions.py +++ b/torch/sparse/_semi_structured_conversions.py @@ -342,11 +342,15 @@ def _compute_compressed_swizzled_bitmask(dense): # [0 0 1 1] # reshape tensor to expand tiles into 8-bit vectors - bitmask_binary_representation = bitmask_4x4_chunks.reshape(*bitmask_4x4_chunks.shape[:2], 4, 2, 8) + bitmask_binary_representation = bitmask_4x4_chunks.reshape( + *bitmask_4x4_chunks.shape[:2], 4, 2, 8 + ) # to convert from binary representaiton, we can do a matmul with powers of two - powers_of_two = 2**torch.arange(8, dtype=torch.float, device="cuda") + powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda") # To run on GPU: cast to float to do matmul and then cast back - compressed_swizzled_bitmask = (bitmask_binary_representation.to(torch.float) @ powers_of_two).to(torch.uint8) + compressed_swizzled_bitmask = ( + bitmask_binary_representation.to(torch.float) @ powers_of_two + ).to(torch.uint8) return compressed_swizzled_bitmask diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index e11bdf59c882..7585fc5e3e64 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -1,14 +1,17 @@ # mypy: allow-untyped-defs import math import os -import torch import weakref from functools import lru_cache -from torch.utils._triton import has_triton -from ._triton_ops_meta import get_meta from typing import Optional, Tuple -TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(os.getenv('TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE', 2)) +import torch +from torch.utils._triton import has_triton +from ._triton_ops_meta import get_meta + +TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int( + os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2) +) 
def check(cond, msg): @@ -34,7 +37,7 @@ def check_mm_compatible_shapes(f_name, lhs, rhs): check( lhs.dim() >= 2 and rhs.dim() >= 2, f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, " - f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}." + f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.", ) m, kl = lhs.shape[-2:] @@ -50,7 +53,8 @@ def check_mm_compatible_shapes(f_name, lhs, rhs): def check_dtype(f_name, t, dtype, *additional_dtypes): check( t.dtype == dtype - and t.dtype in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)), + and t.dtype + in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)), f"{f_name}(): all inputs are expected to be of the same dtype " f"and one of (half, bfloat16, float32) or {additional_dtypes}, " f"but got dtype == {t.dtype}.", @@ -140,7 +144,9 @@ def grid_partitioner(full_grid, grid_blocks, tensor_dims_map): yield next(multidim_slicer(t_dims, slices, t)) for grid_point in itertools.product(*generate_grid_points()): - grid = [min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)] + grid = [ + min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks) + ] slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)] # grid_points are iterated in a "contiguous" order, i.e. # left dimensions traversed slower than right dimensions. @@ -166,7 +172,9 @@ def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None): valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid) ) # type: ignore[assignment] - for grid, *sliced_tensors in grid_partitioner(full_grid, grid_blocks, tensor_dims_map): + for grid, *sliced_tensors in grid_partitioner( + full_grid, grid_blocks, tensor_dims_map + ): kernel(grid, *sliced_tensors) @@ -178,7 +186,9 @@ def prepare_inputs(bsr, *dense_tensors): tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors] # Compute broadcasted batch dimension - batch_dims_broadcasted = torch.broadcast_shapes(values.shape[:-3], *(t.shape[:-2] for t in tensors)) + batch_dims_broadcasted = torch.broadcast_shapes( + values.shape[:-3], *(t.shape[:-2] for t in tensors) + ) # Broadcast batch dimensions and squash. # The result can be either a view or a copy. 
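Illustrative sketch (not part of the patch): the `prepare_inputs` hunk above computes the broadcasted batch shape with `torch.broadcast_shapes` over the batch dimensions of the BSR values and the dense operands. A minimal standalone example of that broadcasting rule, using made-up batch shapes:

```python
import torch

# Hypothetical batch dims standing in for values.shape[:-3] and t.shape[:-2]
values_batch = (1, 3)   # batch dims of bsr.values()
dense_batch = (2, 1)    # batch dims of a dense operand
print(torch.broadcast_shapes(values_batch, dense_batch))  # torch.Size([2, 3])
```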
@@ -191,14 +201,13 @@ def prepare_inputs(bsr, *dense_tensors): crow_indices, batch_dims_broadcasted, (-1,) ) - col_indices = batch_broadcast_and_squash( - col_indices, batch_dims_broadcasted, (-1,) - ) + col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,)) values = batch_broadcast_and_squash( values, batch_dims_broadcasted, values.shape[-3:] ) tensors = [ - batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) for t in tensors + batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) + for t in tensors ] return crow_indices, col_indices, values, *tensors @@ -211,7 +220,9 @@ def broadcast_batch_dims_bsr(f_name, bsr, *tensors): col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,)) values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:]) size = batch_shape + bsr.shape[-2:] - return torch.sparse_compressed_tensor(crow_indices, col_indices, values, size=size, layout=bsr.layout) + return torch.sparse_compressed_tensor( + crow_indices, col_indices, values, size=size, layout=bsr.layout + ) # NOTE: this function will ALWAYS create a view @@ -347,7 +358,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): assert blocks.ndim == 3 P, Ms, Ks = blocks.shape - if indices_format == 'scatter_mm': + if indices_format == "scatter_mm": c_offsets, pq = indices_data[1:] assert others.ndim == 3 @@ -356,7 +367,9 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): if accumulators is None: R = c_offsets.shape[0] - 1 - accumulators = torch.zeros((R, Ms, Ns), dtype=blocks.dtype, device=blocks.device) + accumulators = torch.zeros( + (R, Ms, Ns), dtype=blocks.dtype, device=blocks.device + ) else: R, Ms_, Ns_ = accumulators.shape assert Ms_ == Ms @@ -373,7 +386,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): _scatter_mm2(blocks, others, c_offsets, pq, accumulators) return accumulators - elif indices_format == 'bsr_strided_mm': + elif indices_format == "bsr_strided_mm": others_shape = others.shape others = as1Dbatch(others) @@ -381,11 +394,13 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): assert K % Ks == 0 c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:] - SPLIT_N = meta['SPLIT_N'] + SPLIT_N = meta["SPLIT_N"] if accumulators is None: M = Ms + (r_offsets.max().item() + 1) // N - accumulators = torch.zeros((*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device) + accumulators = torch.zeros( + (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device + ) else: M, N_ = accumulators.shape[-2:] assert N_ == N @@ -403,16 +418,25 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): g0 = c_indices[r].item() g1 = c_indices[r + 1].item() r0, r1 = divmod(r_, N) - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] for g in range(g0, g1): p, q = p_offsets[g], q_offsets[g] q0, q1 = divmod(q.item(), N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] else: - _scatter_mm6(blocks, others, c_indices, r_offsets, p_offsets, q_offsets, meta, accumulators) + _scatter_mm6( + blocks, + others, + c_indices, + r_offsets, + p_offsets, + q_offsets, + meta, + accumulators, + ) return accumulators.view(accumulators_shape) - elif indices_format == 'bsr_strided_mm_compressed': + elif indices_format == "bsr_strided_mm_compressed": others_shape = others.shape others = as1Dbatch(others) @@ -420,11 +444,13 @@ def 
scatter_mm(blocks, others, indices_data, *, accumulators=None): assert K % Ks == 0 c_indices, r_offsets, q_offsets, meta = indices_data[1:] - SPLIT_N = meta['SPLIT_N'] + SPLIT_N = meta["SPLIT_N"] if accumulators is None: M = Ms + (r_offsets.max().item() + 1) // N - accumulators = torch.zeros((*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device) + accumulators = torch.zeros( + (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device + ) else: M, N_ = accumulators.shape[-2:] assert N_ == N @@ -442,26 +468,53 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): n = r1 // Ns c0 = c_indices[m].item() c1 = c_indices[m + 1].item() - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] for i, p in enumerate(range(c0, c1)): q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item() q0, q1 = divmod(q, N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] else: - p_offsets = torch.empty((0, ), dtype=q_offsets.dtype, device=q_offsets.device) - _scatter_mm6(blocks, others, c_indices, r_offsets, p_offsets, q_offsets, meta, accumulators) + p_offsets = torch.empty( + (0,), dtype=q_offsets.dtype, device=q_offsets.device + ) + _scatter_mm6( + blocks, + others, + c_indices, + r_offsets, + p_offsets, + q_offsets, + meta, + accumulators, + ) return accumulators.view(accumulators_shape) else: raise NotImplementedError(indices_format) -def scatter_mm_meta(M, K, N, Ms, Ks, - GROUP_SIZE=None, TILE_M=None, TILE_N=None, SPLIT_N=None, num_warps=None, num_stages=None, **extra): +def scatter_mm_meta( + M, + K, + N, + Ms, + Ks, + GROUP_SIZE=None, + TILE_M=None, + TILE_N=None, + SPLIT_N=None, + num_warps=None, + num_stages=None, + **extra, +): if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}: device_name = torch.cuda.get_device_name() - meta = get_meta('scatter_mm', (M, K, N, Ms, Ks), device_name, - version=(0, torch.float16, 0.5)) + meta = get_meta( + "scatter_mm", + (M, K, N, Ms, Ks), + device_name, + version=(0, torch.float16, 0.5), + ) if meta is not None: meta.update(**extra) return meta @@ -473,51 +526,156 @@ def scatter_mm_meta(M, K, N, Ms, Ks, # parameters are likely different from what specified below. 
if (M, K, N) == (256,) * 3: if (Ms, Ks) == (16, 16): - SPLIT_N=1;TILE_M=16;TILE_N=16;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 1 + TILE_M = 16 + TILE_N = 16 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (Ms, Ks) == (32, 32): - SPLIT_N=2;TILE_M=32;TILE_N=16;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 2 + TILE_M = 32 + TILE_N = 16 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (Ms, Ks) == (64, 64): - SPLIT_N=1;TILE_M=32;TILE_N=32;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 1 + TILE_M = 32 + TILE_N = 32 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (Ms, Ks) == (128, 128): - SPLIT_N=1;TILE_M=32;TILE_N=32;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 1 + TILE_M = 32 + TILE_N = 32 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (M, K, N) == (512,) * 3: if (Ms, Ks) == (16, 16): - SPLIT_N=8;TILE_M=16;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702 + SPLIT_N = 8 + TILE_M = 16 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 elif (Ms, Ks) == (32, 32): - SPLIT_N=8;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=2 # noqa: E225,E231,E702 + SPLIT_N = 8 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 elif (Ms, Ks) == (64, 64): - SPLIT_N=4;TILE_M=32;TILE_N=128;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 4 + TILE_M = 32 + TILE_N = 128 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (Ms, Ks) == (128, 128): - SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 8 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (M, K, N) == (1024,) * 3: if (Ms, Ks) == (16, 16): - SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702 + SPLIT_N = 4 + TILE_M = 16 + TILE_N = 128 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 elif (Ms, Ks) == (32, 32): - SPLIT_N=8;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702 + SPLIT_N = 8 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 elif (Ms, Ks) == (64, 64): - SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=2 # noqa: E225,E231,E702 + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 elif (Ms, Ks) == (128, 128): - SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (Ms, Ks) == (256, 256): - SPLIT_N=16;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (M, K, N) == (2048,) * 3: if (Ms, Ks) == (16, 16): - SPLIT_N=4;TILE_M=16;TILE_N=128;GROUP_SIZE=8;num_stages=1;num_warps=1 # noqa: E225,E231,E702 + SPLIT_N = 4 + TILE_M = 16 + TILE_N = 128 + GROUP_SIZE = 8 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 elif (Ms, Ks) == (32, 32): - SPLIT_N=4;TILE_M=32;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=1 # noqa: E225,E231,E702 + SPLIT_N = 
4 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 elif (Ms, Ks) == (64, 64): - SPLIT_N=4;TILE_M=64;TILE_N=128;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 4 + TILE_M = 64 + TILE_N = 128 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (Ms, Ks) == (128, 128): - SPLIT_N=8;TILE_M=64;TILE_N=64;GROUP_SIZE=4;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 8 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (Ms, Ks) == (256, 256): - SPLIT_N=4;TILE_M=64;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 4 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 elif (M, K, N) == (4096,) * 3: if (Ms, Ks) == (16, 16): - SPLIT_N=2;TILE_M=16;TILE_N=256;GROUP_SIZE=2;num_stages=1;num_warps=2 # noqa: E225,E231,E702 + SPLIT_N = 2 + TILE_M = 16 + TILE_N = 256 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 elif (Ms, Ks) == (32, 32): - SPLIT_N=2;TILE_M=32;TILE_N=64;GROUP_SIZE=2;num_stages=1;num_warps=1 # noqa: E225,E231,E702 + SPLIT_N = 2 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 elif (Ms, Ks) == (64, 64): - SPLIT_N=2;TILE_M=64;TILE_N=128;GROUP_SIZE=2;num_stages=1;num_warps=4 # noqa: E225,E231,E702 + SPLIT_N = 2 + TILE_M = 64 + TILE_N = 128 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 if SPLIT_N is None: # Assume NVIDIA GeForce RTX 2060 SUPER: @@ -526,7 +684,17 @@ def scatter_mm_meta(M, K, N, Ms, Ks, # performance when using an optimal value. Otherwise, when N # <= 512, using the following heuristics may give upto 15% # lower performance. 
- SPLIT_N = {16: 1, 32: 2, 64: 4, 128: 8, 256: 16, 512: 8, 1024: 16, 4096: 32, 8192: 64}.get(N, 16) + SPLIT_N = { + 16: 1, + 32: 2, + 64: 4, + 128: 8, + 256: 16, + 512: 8, + 1024: 16, + 4096: 32, + 8192: 64, + }.get(N, 16) if Ms >= 512 and N >= 2048: SPLIT_N = 1 Ns = N // SPLIT_N @@ -552,12 +720,33 @@ def scatter_mm_meta(M, K, N, Ms, Ks, assert Ns <= N, dict(N=N, Ns=Ns) assert Ks <= K, dict(K=K, Ks=Ks) - return dict(TILE_M=TILE_M, TILE_N=TILE_N, GROUP_SIZE=GROUP_SIZE, - num_stages=num_stages, num_warps=num_warps, SPLIT_N=SPLIT_N, **extra) + return dict( + TILE_M=TILE_M, + TILE_N=TILE_N, + GROUP_SIZE=GROUP_SIZE, + num_stages=num_stages, + num_warps=num_warps, + SPLIT_N=SPLIT_N, + **extra, + ) -def bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, - SPLIT_N=None, GROUP_SIZE_ROW=None, num_warps=None, num_stages=None, sparsity=None, dtype=None, **extra): +def bsr_dense_addmm_meta( + M, + K, + N, + Ms, + Ks, + beta, + alpha, + SPLIT_N=None, + GROUP_SIZE_ROW=None, + num_warps=None, + num_stages=None, + sparsity=None, + dtype=None, + **extra, +): if dtype is None: dtype = torch.float16 if sparsity is None: @@ -565,20 +754,24 @@ def bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: device_name = torch.cuda.get_device_name() key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) - meta = get_meta('bsr_dense_addmm', key, - device_name, version=(0, dtype, sparsity)) + meta = get_meta( + "bsr_dense_addmm", key, device_name, version=(0, dtype, sparsity) + ) if meta is None and sparsity != 0.5: - meta = get_meta('bsr_dense_addmm', key, - device_name, version=(0, dtype, 0.5)) + meta = get_meta( + "bsr_dense_addmm", key, device_name, version=(0, dtype, 0.5) + ) if meta is None: # find approximate meta such that N % SPLIT_N == 0. matching_meta = get_meta( - 'bsr_dense_addmm', - (*key[:2], '*', *key[3:]), - device_name, version=(0, dtype, 0.5)) + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(0, dtype, 0.5), + ) for mkey in sorted(matching_meta or {}): meta_ = matching_meta[mkey] - if N % meta_['SPLIT_N'] == 0 and mkey[2] <= N: + if N % meta_["SPLIT_N"] == 0 and mkey[2] <= N: meta = meta_ if meta is not None: meta.update(**extra) @@ -587,7 +780,13 @@ def bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4 num_stages = num_stages or 1 num_warps = num_warps or 4 - return dict(SPLIT_N=SPLIT_N, GROUP_SIZE_ROW=GROUP_SIZE_ROW, num_stages=num_stages, num_warps=num_warps, **extra) + return dict( + SPLIT_N=SPLIT_N, + GROUP_SIZE_ROW=GROUP_SIZE_ROW, + num_stages=num_stages, + num_warps=num_warps, + **extra, + ) class TensorAsKey: @@ -614,7 +813,6 @@ class TensorAsKey: """ def __init__(self, obj): - def get_tensor_key(obj): # Warning: TensorAsKey does not track negative nor # conjugate bits of its input object because in the use @@ -626,15 +824,27 @@ class TensorAsKey: # and is_conj methods) must be included in the key as # well. 
assert not (obj.dtype.is_floating_point or obj.dtype.is_complex), obj.dtype - return (obj.data_ptr(), obj.storage_offset(), obj.shape, obj.stride(), obj.dtype) + return ( + obj.data_ptr(), + obj.storage_offset(), + obj.shape, + obj.stride(), + obj.dtype, + ) self._obj_ref = weakref.ref(obj) if obj.layout is torch.strided: self.key = get_tensor_key(obj) elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: - self.key = (get_tensor_key(obj.crow_indices()), get_tensor_key(obj.col_indices())) + self.key = ( + get_tensor_key(obj.crow_indices()), + get_tensor_key(obj.col_indices()), + ) elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}: - self.key = (get_tensor_key(obj.ccol_indices()), get_tensor_key(obj.row_indices())) + self.key = ( + get_tensor_key(obj.ccol_indices()), + get_tensor_key(obj.row_indices()), + ) else: raise NotImplementedError(obj.layout) self._hash = hash(self.key) @@ -658,14 +868,16 @@ class TensorAsKey: @lru_cache(maxsize=TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE) -def _bsr_scatter_mm_indices_data(indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key): +def _bsr_scatter_mm_indices_data( + indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key +): bsr = compressed_sparse_tensor_as_key.obj assert bsr is not None crow_indices, col_indices = bsr.crow_indices(), bsr.col_indices() device = crow_indices.device indices_dtype = torch.int32 - if indices_format == 'bsr_strided_mm_compressed': + if indices_format == "bsr_strided_mm_compressed": Ns = N // SPLIT_N q_offsets_lst = [] b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns @@ -674,7 +886,10 @@ def _bsr_scatter_mm_indices_data(indices_format, M, K, N, Ms, Ks, nbatches, SPLI r1 = crow_indices[m + 1].item() if r1 == r0: continue - q_offsets_lst.append((col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + b.repeat_interleave(r1 - r0)) + q_offsets_lst.append( + (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + + b.repeat_interleave(r1 - r0) + ) q_offsets = torch.cat(q_offsets_lst) crow_indices_diff = crow_indices.diff() non_zero_row_indices = crow_indices_diff.nonzero() @@ -687,7 +902,7 @@ def _bsr_scatter_mm_indices_data(indices_format, M, K, N, Ms, Ks, nbatches, SPLI r_offsets = r_offsets[indices] return (indices_format, c_indices, r_offsets, q_offsets) - elif indices_format == 'bsr_strided_mm': + elif indices_format == "bsr_strided_mm": Ns = N // SPLIT_N p_offsets_lst = [] q_offsets_lst = [] @@ -697,19 +912,31 @@ def _bsr_scatter_mm_indices_data(indices_format, M, K, N, Ms, Ks, nbatches, SPLI r1 = crow_indices[m + 1].item() if r1 == r0: continue - p_offsets_lst.append(torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N)) - q_offsets_lst.append((col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + b.repeat_interleave(r1 - r0)) + p_offsets_lst.append( + torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N) + ) + q_offsets_lst.append( + (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + + b.repeat_interleave(r1 - r0) + ) q_offsets = torch.cat(q_offsets_lst) crow_indices_diff = crow_indices.diff() non_zero_row_indices = crow_indices_diff.nonzero() a = non_zero_row_indices * (Ms * N) r_offsets = (a + b).view(-1) - c_indices = torch.cat((crow_indices[:1], - torch.cumsum(crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N), 0))) + c_indices = torch.cat( + ( + crow_indices[:1], + torch.cumsum( + crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N), + 0, + ), + ) + ) p_offsets = 
torch.cat(p_offsets_lst) return (indices_format, c_indices, r_offsets, p_offsets, q_offsets) - elif indices_format == 'scatter_mm': + elif indices_format == "scatter_mm": Ns = Ms c_indices = [0] pq_offsets = [] @@ -725,15 +952,21 @@ def _bsr_scatter_mm_indices_data(indices_format, M, K, N, Ms, Ks, nbatches, SPLI q = (col_indices[p].item() + b * (K // Ks)) * (N // Ns) + n pq_offsets.append([p, q]) - return (indices_format, - torch.tensor(c_indices, dtype=indices_dtype, device=device), - torch.tensor(pq_offsets, dtype=indices_dtype, device=device)) + return ( + indices_format, + torch.tensor(c_indices, dtype=indices_dtype, device=device), + torch.tensor(pq_offsets, dtype=indices_dtype, device=device), + ) else: - raise ValueError(f'Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm') + raise ValueError( + f"Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm" + ) -def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compressed', **meta_input): +def bsr_scatter_mm_indices_data( + bsr, other, indices_format="bsr_strided_mm_compressed", **meta_input +): """Computes indices data for :func:`scatter_mm` used in BSR and strided tensor matrix multiplication. """ @@ -749,16 +982,17 @@ def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compr nbatches = other.shape[:-2].numel() meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input) - if 'allow_tf32' not in meta_input: + if "allow_tf32" not in meta_input: meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16}) - SPLIT_N = meta['SPLIT_N'] + SPLIT_N = meta["SPLIT_N"] indices_data = _bsr_scatter_mm_indices_data( - indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr)) + indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr) + ) - if indices_format == 'bsr_strided_mm_compressed': + if indices_format == "bsr_strided_mm_compressed": meta.update(is_compressed=True) return indices_data + (meta,) - elif indices_format == 'bsr_strided_mm': + elif indices_format == "bsr_strided_mm": meta.update(is_compressed=False) return indices_data + (meta,) else: @@ -766,8 +1000,7 @@ def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compr def bsr_scatter_mm(bsr, other, indices_data=None, out=None): - """BSR @ strided -> strided - """ + """BSR @ strided -> strided""" assert bsr.ndim == 2 assert other.ndim >= 2 @@ -776,36 +1009,61 @@ def bsr_scatter_mm(bsr, other, indices_data=None, out=None): blocksize = bsr.values().shape[-2:] if indices_data is None: - indices_data = bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compressed') + indices_data = bsr_scatter_mm_indices_data( + bsr, other, indices_format="bsr_strided_mm_compressed" + ) indices_format = indices_data[0] if out is None: - out = torch.empty((*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device) + out = torch.empty( + (*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device + ) out_shape = out.shape out = as1Dbatch(out) if bsr._nnz() == 0: out.zero_() - elif indices_format in {'bsr_strided_mm_compressed', 'bsr_strided_mm'}: + elif indices_format in {"bsr_strided_mm_compressed", "bsr_strided_mm"}: out.zero_() scatter_mm(bsr.values(), other, indices_data, accumulators=out) - elif indices_format == 'scatter_mm': + elif indices_format == "scatter_mm": nbatches = other.shape[:-2].numel() - accumulators = torch.zeros((nbatches * Ms // blocksize[0] * Ns // blocksize[0], blocksize[0], blocksize[0]), - 
dtype=bsr.dtype, device=bsr.device) - others = (as1Dbatch(other) - .transpose(-2, -1) - .view(nbatches, Ns // blocksize[0], blocksize[0], Ks // blocksize[1], blocksize[1]) - .movedim((3, 1, 4, 2), (1, 2, 3, 4)) # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3) - .flatten(0, 2) - ) + accumulators = torch.zeros( + ( + nbatches * Ms // blocksize[0] * Ns // blocksize[0], + blocksize[0], + blocksize[0], + ), + dtype=bsr.dtype, + device=bsr.device, + ) + others = ( + as1Dbatch(other) + .transpose(-2, -1) + .view( + nbatches, + Ns // blocksize[0], + blocksize[0], + Ks // blocksize[1], + blocksize[1], + ) + .movedim( + (3, 1, 4, 2), (1, 2, 3, 4) + ) # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3) + .flatten(0, 2) + ) scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators) - out.copy_(accumulators - .unflatten(0, (nbatches, Ms // blocksize[0], Ns // blocksize[0])) - .movedim((1, 2, 3, 4), (3, 1, 4, 2)) # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2) - .reshape(nbatches, Ns, Ms) - .transpose(-2, -1)) + out.copy_( + accumulators.unflatten( + 0, (nbatches, Ms // blocksize[0], Ns // blocksize[0]) + ) + .movedim( + (1, 2, 3, 4), (3, 1, 4, 2) + ) # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2) + .reshape(nbatches, Ns, Ms) + .transpose(-2, -1) + ) else: raise NotImplementedError(indices_format) @@ -813,23 +1071,24 @@ def bsr_scatter_mm(bsr, other, indices_data=None, out=None): def bsr_dense_addmm( - input: torch.Tensor, - bsr: torch.Tensor, - dense: torch.Tensor, - *, - beta=1, - alpha=1, - out: Optional[torch.Tensor] = None, - skip_checks: bool = False, - max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, - meta: Optional[dict] = None): - f_name = 'bsr_dense_addmm' + input: torch.Tensor, + bsr: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + f_name = "bsr_dense_addmm" values = bsr.values() crow_indices = bsr.crow_indices() col_indices = bsr.col_indices() batch_ndim = crow_indices.dim() - 1 - M, K = bsr.shape[batch_ndim:batch_ndim + 2] - blocksize = values.shape[batch_ndim + 1:batch_ndim + 3] + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3] N = dense.shape[-1] # todo: implement checks @@ -849,13 +1108,25 @@ def bsr_dense_addmm( if meta is None: sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) - meta = bsr_dense_addmm_meta(M, K, N, blocksize[0], blocksize[1], beta, alpha, sparsity=sparsity, dtype=out.dtype) + meta = bsr_dense_addmm_meta( + M, + K, + N, + blocksize[0], + blocksize[1], + beta, + alpha, + sparsity=sparsity, + dtype=out.dtype, + ) out_backup = out - crow_indices, col_indices, values, input, dense, out = prepare_inputs(bsr, input, dense, out) + crow_indices, col_indices, values, input, dense, out = prepare_inputs( + bsr, input, dense, out + ) BM, BK = blocksize - SPLIT_N = meta.get('SPLIT_N', N // BM) + SPLIT_N = meta.get("SPLIT_N", N // BM) BN = N // SPLIT_N out_untiled = out @@ -863,10 +1134,12 @@ def bsr_dense_addmm( dense = tile_to_blocksize(dense, (BK, BN)) input = tile_to_blocksize(input, (BM, BN)) - dot_out_dtype = {torch.float16: tl.float32, - torch.bfloat16: tl.float32, - torch.float32: tl.float64, - torch.float64: tl.float64}[out.dtype] + dot_out_dtype = { + 
torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + }[out.dtype] n_batches = dense.size(0) n_block_rows = crow_indices.size(-1) - 1 @@ -892,7 +1165,8 @@ def bsr_dense_addmm( def kernel(grid, *sliced_tensors): _bsr_strided_addmm_kernel[grid]( *ptr_stride_extractor(*sliced_tensors), - beta, alpha, + beta, + alpha, beta_is_one=beta == 1, beta_is_nonzero=beta != 0, alpha_is_one=alpha == 1, @@ -901,7 +1175,8 @@ def bsr_dense_addmm( BLOCKSIZE_COL=BN, allow_tf32=dot_out_dtype == tl.float32, acc_dtype=dot_out_dtype, - **meta) + **meta, + ) launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) @@ -1014,19 +1289,22 @@ if has_triton(): mask_k = k_offsets < k mat1_block = tl.load( - mat1_block_ptrs - + mat1_col_block_stride * k_offsets[None, :], - mask=mask_k[None, :], other=0.0 + mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :], + mask=mask_k[None, :], + other=0.0, ) mat2_block = tl.load( mat2_block_ptrs + mat2_tiled_col_stride * col_block + mat2_row_block_stride * k_offsets[:, None], - mask=mask_k[:, None], other=0.0 + mask=mask_k[:, None], + other=0.0, ) - acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype) + acc_block += tl.dot( + mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) if IS_BETA_ZERO: acc_block *= alpha @@ -1156,10 +1434,14 @@ if has_triton(): # find which row of dense needs to get loaded # for multiplication with values_block. dense_row_idx = tl.load(col_index_nnz_ptr) - dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx + ) # do block mm - output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype) + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) # move val/col_index ptrs to the next block in the row values_block_ptrs += values_nnz_stride @@ -1168,13 +1450,19 @@ if has_triton(): # write back the result tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) - def _run_sampled_addmm_kernel( - alpha, beta, is_beta_zero, - blocksize, k, tile_k, - values, crow_indices, col_indices, - mat1, mat2, - max_grid + alpha, + beta, + is_beta_zero, + blocksize, + k, + tile_k, + values, + crow_indices, + col_indices, + mat1, + mat2, + max_grid, ): n_batches = values.size(0) n_block_rows = crow_indices.size(-1) - 1 @@ -1200,18 +1488,21 @@ if has_triton(): def kernel(grid, *sliced_tensors): _sampled_addmm_kernel[grid]( - alpha, beta, is_beta_zero, - *blocksize, k, tile_k, + alpha, + beta, + is_beta_zero, + *blocksize, + k, + tile_k, *ptr_stride_extractor(*sliced_tensors), acc_dtype=acc_dtype, allow_tf32=allow_tf32, num_stages=1, - num_warps=4 + num_warps=4, ) launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) - def sampled_addmm( input: torch.Tensor, mat1: torch.Tensor, @@ -1234,7 +1525,7 @@ if has_triton(): if beta != 0.0 and input.dtype is torch.bool: check( False, - f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed." 
+ f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.", ) if input.dtype is not torch.bool: check_dtype(f_name, mat1, input.dtype) @@ -1247,11 +1538,10 @@ if has_triton(): check_device(f_name, out, mat1.device) check_dtype(f_name, out, input.dtype) check( - out.shape == input_broadcasted.shape - and out._nnz() == input._nnz(), + out.shape == input_broadcasted.shape and out._nnz() == input._nnz(), f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} " f"and with nnz equal to {input_broadcasted._nnz()} " - f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}" + f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}", ) if out is None: @@ -1281,11 +1571,18 @@ if has_triton(): tile_k = max(*blocksize) _run_sampled_addmm_kernel( - alpha, beta, beta == 0.0, - blocksize, k, tile_k, - values, crow_indices, col_indices, - mat1, mat2, - max_grid + alpha, + beta, + beta == 0.0, + blocksize, + k, + tile_k, + values, + crow_indices, + col_indices, + mat1, + mat2, + max_grid, ) # If nnz x block strides are not the same in out_backup.values and values, @@ -1295,7 +1592,6 @@ if has_triton(): out_backup.values().copy_(values.reshape(out_backup.values().shape)) return out_backup - def bsr_dense_mm( bsr: torch.Tensor, dense: torch.Tensor, @@ -1303,7 +1599,7 @@ if has_triton(): out: Optional[torch.Tensor] = None, skip_checks: bool = False, max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, - meta: Optional[dict] = None + meta: Optional[dict] = None, ): f_name = "bsr_dense_mm" m, kl = bsr.shape[-2:] @@ -1318,7 +1614,7 @@ if has_triton(): check_blocksize(f_name, (row_block, col_block)) check( not n % 16, - f"{f_name}(): dense.size(-1) == {n} should be divisible by 16" + f"{f_name}(): dense.size(-1) == {n} should be divisible by 16", ) else: kr, n = dense.shape[-2:] @@ -1351,7 +1647,6 @@ if has_triton(): # as a placeholder for input because their shapes match: return bsr_dense_addmm(out, bsr, dense, alpha=1, beta=0, out=out) - @triton.jit def _bsr_softmax_kernel( crow_indices_ptr, @@ -1361,9 +1656,10 @@ if has_triton(): values_batch_stride, values_row_block_stride, values_nnz_col_block_stride, - row_block, col_block, + row_block, + col_block, MAX_ROW_NNZ: tl.constexpr, - TILE: tl.constexpr + TILE: tl.constexpr, ): batch_pid = tl.program_id(axis=2) row_block_offset_pid = tl.program_id(axis=1) @@ -1394,14 +1690,20 @@ if has_triton(): ) # find max in the row - row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32) + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) max_row_value = tl.max(row_tile, axis=0) for _ in range(TILE, MAX_ROW_NNZ, TILE): row_arange += TILE mask = row_arange < row_nnz * col_block - row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32) + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) curr_max_row_value = tl.max(row_tile, axis=0) - max_row_value = tl.where(max_row_value > curr_max_row_value, max_row_value, curr_max_row_value) + max_row_value = tl.where( + max_row_value > curr_max_row_value, max_row_value, curr_max_row_value + ) # find denominator for stable softmax num = tl.exp(row_tile - max_row_value) @@ -1409,19 +1711,30 @@ if has_triton(): for _ in range(TILE, MAX_ROW_NNZ, TILE): row_arange -= TILE mask = row_arange < row_nnz * col_block - row_tile = tl.load(curr_row_values_ptrs + row_arange, 
mask=mask, other=-float('inf')).to(tl.float32) + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) num = tl.exp(row_tile - max_row_value) denom += tl.sum(num, axis=0) # populate output - tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask) + tl.store( + curr_row_values_ptrs + row_arange, + (num / denom).to(values_ptr.dtype.element_ty), + mask=mask, + ) for _ in range(TILE, MAX_ROW_NNZ, TILE): row_arange += TILE mask = row_arange < row_nnz * col_block - row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32) + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) num = tl.exp(row_tile - max_row_value) - tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask) - + tl.store( + curr_row_values_ptrs + row_arange, + (num / denom).to(values_ptr.dtype.element_ty), + mask=mask, + ) def bsr_softmax(input, max_row_nnz=None): f_name = "bsr_softmax" @@ -1452,7 +1765,13 @@ if has_triton(): values = input.values().clone() else: values = input.values() - values = values.transpose(-3, -2).contiguous().unsqueeze(0).flatten(0, -4).reshape(-1, row_block, nnz * col_block) + values = ( + values.transpose(-3, -2) + .contiguous() + .unsqueeze(0) + .flatten(0, -4) + .reshape(-1, row_block, nnz * col_block) + ) full_grid = (values.shape[0], row_block, m // row_block) grid_blocks = None tensor_dims_map = { @@ -1465,22 +1784,27 @@ if has_triton(): def kernel(grid, *sliced_tensors): _bsr_softmax_kernel[grid]( *ptr_stride_extractor(*sliced_tensors), - row_block, col_block, + row_block, + col_block, max_row_nnz, # Triton's max numel is bounded by 2 ** 17. - min(2 ** 17, max_row_nnz) + min(2**17, max_row_nnz), ) launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) - values = values.reshape(-1, row_block, nnz, col_block).transpose(-3, -2).reshape(*input.values().shape) + values = ( + values.reshape(-1, row_block, nnz, col_block) + .transpose(-3, -2) + .reshape(*input.values().shape) + ) return torch.sparse_compressed_tensor( input.crow_indices().clone(), input.col_indices().clone(), values, size=input.shape, - layout=input.layout + layout=input.layout, ) def _scaled_dot_product_attention( @@ -1490,24 +1814,18 @@ if has_triton(): attn_mask: Optional[torch.Tensor], dropout_p: float = 0.0, is_causal: bool = False, - scale: Optional[float] = None + scale: Optional[float] = None, ): f_name = "_scaled_dot_product_attention" - check( - not is_causal, - f"{f_name}(): is_causal == True is not supported." - ) - check( - attn_mask is not None, - f"{f_name}(): attn_mask == None is not supported." - ) + check(not is_causal, f"{f_name}(): is_causal == True is not supported.") + check(attn_mask is not None, f"{f_name}(): attn_mask == None is not supported.") assert attn_mask is not None check( attn_mask.layout == torch.sparse_bsr, f"{f_name}(): " f"attn_mask.layout must be {torch.sparse_bsr}, but got " - f"attn_mask.layout == {attn_mask.layout}." 
+ f"attn_mask.layout == {attn_mask.layout}.", ) check_device(f_name, key, query.device) @@ -1519,12 +1837,14 @@ if has_triton(): if attn_mask.dtype is not torch.bool: check_dtype(f_name, attn_mask, query.dtype) - sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False) + sdpa = sampled_addmm( + attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False + ) if scale is None and query.size(-1) == 0 or scale == 0.0: check( False, f"{f_name}(): current value of scale == {scale} " - "results in division by zero." + "results in division by zero.", ) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale sdpa.values().mul_(scale_factor) @@ -1535,17 +1855,31 @@ if has_triton(): @triton.jit def _scatter_mm2_kernel( - M: tl.constexpr, K: tl.constexpr, N: tl.constexpr, - blocks_ptr, blocks_stride_P, blocks_stride_M, blocks_stride_K, - others_ptr, others_stride_Q, others_stride_K, others_stride_N, - accumulators_ptr, accumulators_stride_R, accumulators_stride_M, accumulators_stride_N, - pq_offsets_ptr, pq_offsets_stride, - pq_ptr, pq_stride_T, pq_stride_1, - dot_out_dtype: tl.constexpr, - TILE_M: tl.constexpr, - TILE_N: tl.constexpr, - allow_tf32: tl.constexpr): - + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + blocks_ptr, + blocks_stride_P, + blocks_stride_M, + blocks_stride_K, + others_ptr, + others_stride_Q, + others_stride_K, + others_stride_N, + accumulators_ptr, + accumulators_stride_R, + accumulators_stride_M, + accumulators_stride_N, + pq_offsets_ptr, + pq_offsets_stride, + pq_ptr, + pq_stride_T, + pq_stride_1, + dot_out_dtype: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + allow_tf32: tl.constexpr, + ): Ms = M // TILE_M Ns = N // TILE_N @@ -1555,12 +1889,16 @@ if has_triton(): pid_m = pid // Ms pid_n = pid % Ms - rm = (pid_m * TILE_M + tl.arange(0, TILE_M)) - rn = (pid_n * TILE_N + tl.arange(0, TILE_N)) + rm = pid_m * TILE_M + tl.arange(0, TILE_M) + rn = pid_n * TILE_N + tl.arange(0, TILE_N) rk = tl.arange(0, K) - A_ptr = blocks_ptr + (rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K) - B_ptr = others_ptr + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N) + A_ptr = blocks_ptr + ( + rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K + ) + B_ptr = others_ptr + ( + rk[:, None] * others_stride_K + rn[None, :] * others_stride_N + ) g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride) g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride) @@ -1577,58 +1915,101 @@ if has_triton(): B = tl.load(B_ptr + q * others_stride_Q) acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) - C_ptr = accumulators_ptr + pid_t * accumulators_stride_R + ( - rm[:, None] * accumulators_stride_M + rn[None, :] * accumulators_stride_N) + C_ptr = ( + accumulators_ptr + + pid_t * accumulators_stride_R + + ( + rm[:, None] * accumulators_stride_M + + rn[None, :] * accumulators_stride_N + ) + ) tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty)) def _scatter_mm2( - blocks: torch.Tensor, - others: torch.Tensor, - pq_offsets: torch.Tensor, - pq_indices: torch.Tensor, - accumulators: torch.Tensor + blocks: torch.Tensor, + others: torch.Tensor, + pq_offsets: torch.Tensor, + pq_indices: torch.Tensor, + accumulators: torch.Tensor, ): P, M, K = blocks.shape Q, _, N = others.shape R, _, _ = accumulators.shape - meta = dict(TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2) + meta = dict( + TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), 
num_stages=1, num_warps=2 + ) def grid(META): - return (pq_offsets.shape[0] - 1, triton.cdiv(M, META['TILE_M']) * triton.cdiv(N, META['TILE_N']), 1) + return ( + pq_offsets.shape[0] - 1, + triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]), + 1, + ) - dot_out_dtype = {torch.float16: tl.float32, - torch.bfloat16: tl.float32, - torch.float32: tl.float64, - torch.float64: tl.float64}[accumulators.dtype] - if 'allow_tf32' not in meta: + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + }[accumulators.dtype] + if "allow_tf32" not in meta: meta.update(allow_tf32=dot_out_dtype == tl.float32) _scatter_mm2_kernel[grid]( - M, K, N, - blocks, blocks.stride(0), blocks.stride(1), blocks.stride(2), - others, others.stride(0), others.stride(1), others.stride(2), - accumulators, accumulators.stride(0), accumulators.stride(1), accumulators.stride(2), - pq_offsets, pq_offsets.stride(0), - pq_indices, pq_indices.stride(0), pq_indices.stride(1), + M, + K, + N, + blocks, + blocks.stride(0), + blocks.stride(1), + blocks.stride(2), + others, + others.stride(0), + others.stride(1), + others.stride(2), + accumulators, + accumulators.stride(0), + accumulators.stride(1), + accumulators.stride(2), + pq_offsets, + pq_offsets.stride(0), + pq_indices, + pq_indices.stride(0), + pq_indices.stride(1), dot_out_dtype=dot_out_dtype, - **meta + **meta, ) @triton.jit def _scatter_mm6_kernel( - nbatches, Ms, Ks: tl.constexpr, N, - blocks_ptr, blocks_stride_P, blocks_stride_M, blocks_stride_K, - others_ptr, others_stride_B, others_stride_K, others_stride_N, - accumulators_ptr, accumulators_stride_B, accumulators_stride_M, accumulators_stride_N, - c_indices_ptr, r_offsets_ptr, - p_offsets_ptr, q_offsets_ptr, - is_compressed: tl.constexpr, - dot_out_dtype: tl.constexpr, - SPLIT_N: tl.constexpr, - TILE_M: tl.constexpr, - TILE_N: tl.constexpr, - GROUP_SIZE: tl.constexpr, - allow_tf32: tl.constexpr): + nbatches, + Ms, + Ks: tl.constexpr, + N, + blocks_ptr, + blocks_stride_P, + blocks_stride_M, + blocks_stride_K, + others_ptr, + others_stride_B, + others_stride_K, + others_stride_N, + accumulators_ptr, + accumulators_stride_B, + accumulators_stride_M, + accumulators_stride_N, + c_indices_ptr, + r_offsets_ptr, + p_offsets_ptr, + q_offsets_ptr, + is_compressed: tl.constexpr, + dot_out_dtype: tl.constexpr, + SPLIT_N: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + allow_tf32: tl.constexpr, + ): Ns = N // SPLIT_N BLOCKS_M = Ms // TILE_M BLOCKS_N = Ns // TILE_N @@ -1645,11 +2026,17 @@ if has_triton(): pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m - rm = (pid_m * TILE_M + tl.arange(0, TILE_M)) - rn = (pid_n * TILE_N + tl.arange(0, TILE_N)) + rm = pid_m * TILE_M + tl.arange(0, TILE_M) + rn = pid_n * TILE_N + tl.arange(0, TILE_N) rk = tl.arange(0, Ks) - A_ptr = blocks_ptr + (rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K) - B_ptr = others_ptr + pid_b * others_stride_B + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N) + A_ptr = blocks_ptr + ( + rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K + ) + B_ptr = ( + others_ptr + + pid_b * others_stride_B + + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N) + ) # When is_compressed is True, r is the only variable that # depends on pid_t. 
This property allows sorting r values @@ -1678,7 +2065,9 @@ if has_triton(): q = tl.load(q_ptr) B = tl.load(B_ptr + q) A = tl.load(A_ptr) - acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + acc_block += tl.dot( + A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32 + ) A_ptr += blocks_stride_P q_ptr += 1 else: @@ -1690,24 +2079,33 @@ if has_triton(): A = tl.load(A_ptr + p * blocks_stride_P) p_ptr += 1 q_ptr += 1 - acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + acc_block += tl.dot( + A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32 + ) - C_ptr = accumulators_ptr + r + pid_b * accumulators_stride_B + ( - rm[:, None] * accumulators_stride_M + rn[None, :] * accumulators_stride_N) + C_ptr = ( + accumulators_ptr + + r + + pid_b * accumulators_stride_B + + ( + rm[:, None] * accumulators_stride_M + + rn[None, :] * accumulators_stride_N + ) + ) tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty)) def _scatter_mm6( - blocks: torch.Tensor, - others: torch.Tensor, - c_indices: torch.Tensor, - r_offsets: torch.Tensor, - p_offsets: torch.Tensor, - q_offsets: torch.Tensor, - meta: dict, - accumulators: torch.Tensor, - force_contiguous: bool = True, + blocks: torch.Tensor, + others: torch.Tensor, + c_indices: torch.Tensor, + r_offsets: torch.Tensor, + p_offsets: torch.Tensor, + q_offsets: torch.Tensor, + meta: dict, + accumulators: torch.Tensor, + force_contiguous: bool = True, ): - SPLIT_N = meta['SPLIT_N'] + SPLIT_N = meta["SPLIT_N"] P, Ms, Ks = blocks.shape B, K_, N = others.shape B_, M, N_ = accumulators.shape @@ -1716,13 +2114,18 @@ if has_triton(): assert B_ == B def grid(META): - return (r_offsets.shape[0] * B, triton.cdiv(Ms, META['TILE_M']) * triton.cdiv(Ns, META['TILE_N'])) + return ( + r_offsets.shape[0] * B, + triton.cdiv(Ms, META["TILE_M"]) * triton.cdiv(Ns, META["TILE_N"]), + ) - dot_out_dtype = {torch.float16: tl.float32, - torch.bfloat16: tl.float32, - torch.float32: tl.float64, - torch.float64: tl.float64}[accumulators.dtype] - if 'allow_tf32' not in meta: + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + }[accumulators.dtype] + if "allow_tf32" not in meta: meta.update(allow_tf32=dot_out_dtype == tl.float32) assert c_indices.stride(0) == 1 @@ -1754,16 +2157,28 @@ if has_triton(): accumulators_ = accumulators _scatter_mm6_kernel[grid]( - B, Ms, Ks, N, - blocks, blocks.stride(0), blocks.stride(1), blocks.stride(2), - others, others.stride(0), others.stride(1), others.stride(2), - accumulators_, accumulators_.stride(0), accumulators_.stride(1), accumulators_.stride(2), + B, + Ms, + Ks, + N, + blocks, + blocks.stride(0), + blocks.stride(1), + blocks.stride(2), + others, + others.stride(0), + others.stride(1), + others.stride(2), + accumulators_, + accumulators_.stride(0), + accumulators_.stride(1), + accumulators_.stride(2), c_indices, r_offsets, p_offsets, q_offsets, dot_out_dtype=dot_out_dtype, - **meta + **meta, ) if force_contiguous and not accumulators.is_contiguous(): @@ -1823,9 +2238,8 @@ if has_triton(): acc_dtype: tl.constexpr, allow_tf32: tl.constexpr, GROUP_SIZE_ROW: tl.constexpr, - SPLIT_N: tl.constexpr + SPLIT_N: tl.constexpr, ): - batch_pid = tl.program_id(axis=2) row_block_pid = tl.program_id(axis=0) col_block_pid = tl.program_id(axis=1) @@ -1913,10 +2327,14 @@ if has_triton(): # find which row of dense needs to get loaded # for multiplication with values_block. 
dense_row_idx = tl.load(col_index_nnz_ptr) - dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx + ) # do block mm - output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype) + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) # move val/col_index ptrs to the next block in the row values_block_ptrs += values_nnz_stride @@ -1928,7 +2346,6 @@ if has_triton(): # write back the result tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) - else: bsr_softmax = None # type: ignore[assignment] bsr_dense_mm = None # type: ignore[assignment] diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 6105038e4df7..23193021232a 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -1,23 +1,23 @@ # mypy: allow-untyped-defs import warnings from collections import namedtuple -from typing import Any, Optional, Tuple, List, Callable, Dict +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch.sparse._semi_structured_conversions import ( sparse_semi_structured_from_dense_cutlass, - sparse_semi_structured_to_dense_cutlass + sparse_semi_structured_to_dense_cutlass, ) from torch.sparse._semi_structured_ops import ( fallback_dispatcher, - semi_sparse_values, - semi_sparse_indices, - semi_sparse_detach, - semi_sparse_t, - semi_sparse_view, - semi_sparse_mm, semi_sparse_addmm, + semi_sparse_detach, + semi_sparse_indices, semi_sparse_linear, + semi_sparse_mm, + semi_sparse_t, + semi_sparse_values, + semi_sparse_view, ) __all__ = [ @@ -175,7 +175,7 @@ class SparseSemiStructuredTensor(torch.Tensor): def __tensor_unflatten__( cls, inner_tensors, - tensor_meta : Tuple[torch.Size, bool, int, bool], + tensor_meta: Tuple[torch.Size, bool, int, bool], outer_size, outer_stride, ) -> torch.Tensor: @@ -186,7 +186,9 @@ class SparseSemiStructuredTensor(torch.Tensor): meta=inner_tensors.get("meta", None), packed_t=inner_tensors.get("packed_t", None), meta_t=inner_tensors.get("meta_t", None), - compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None), + compressed_swizzled_bitmask=inner_tensors.get( + "compressed_swizzled_bitmask", None + ), fuse_transpose_cusparselt=fuse_transpose_cusparselt, alg_id_cusparselt=alg_id_cusparselt, requires_grad=requires_grad, @@ -227,7 +229,7 @@ class SparseSemiStructuredTensor(torch.Tensor): cls.SPARSE_DISPATCH.update(custom_dispatch_table) @classmethod - def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None: + def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None: """ Assert that the given tensor is valid for semi-structured sparse compression. """ @@ -297,7 +299,7 @@ class SparseSemiStructuredTensor(torch.Tensor): return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) @classmethod - def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor": + def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor": raise NotImplementedError def _mm( @@ -377,6 +379,7 @@ def to_sparse_semi_structured( return SPARSE_SUBCLASS.from_dense(original_tensor) + class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): """ This class implements semi-structured sparsity for the CUTLASS backend. 
@@ -388,6 +391,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and sparse_semi_structured_from_dense for conversion to the compressed format. """ + BACKEND = "cutlass" _DTYPE_SHAPE_CONSTRAINTS = { torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16), @@ -417,13 +421,19 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): def to_dense(self): assert self.meta is not None and self.packed is not None - return sparse_semi_structured_to_dense_cutlass( - self.packed, - self.meta, - ) if self.meta.ndim == 2 else super().to_dense() + return ( + sparse_semi_structured_to_dense_cutlass( + self.packed, + self.meta, + ) + if self.meta.ndim == 2 + else super().to_dense() + ) @classmethod - def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor": + def prune_dense_static_sort( + cls, original_tensor: torch.Tensor, algorithm="" + ) -> "SparseSemiStructuredTensor": """ This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile. @@ -463,10 +473,15 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): ``` """ # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag. - (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile( - original_tensor, - algorithm=algorithm, - use_cutlass=True) + ( + packed, + meta, + packed_t, + meta_t, + compressed_swizzled_bitmask, + ) = torch._sparse_semi_structured_tile( + original_tensor, algorithm=algorithm, use_cutlass=True + ) return cls( original_tensor.shape, @@ -479,11 +494,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): ) def _mm( - self, - B: torch.Tensor, - *, - bias: Optional[torch.Tensor] = None, - **kwargs + self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs ) -> torch.Tensor: if isinstance(B, SparseSemiStructuredTensor): raise ValueError( @@ -500,9 +511,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): ) else: if bias is None: - res = torch._sparse_semi_structured_mm( - self.packed, self.meta, B - ) + res = torch._sparse_semi_structured_mm(self.packed, self.meta, B) else: res = torch._sparse_semi_structured_addmm( bias, self.packed, self.meta, B @@ -521,6 +530,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes. 
""" + BACKEND = "cusparselt" _DTYPE_SHAPE_CONSTRAINTS = { torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16), @@ -530,7 +540,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): } @classmethod - def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT": + def from_dense( + cls, original_tensor: torch.Tensor + ) -> "SparseSemiStructuredTensorCUSPARSELT": cls._validate_device_dim_dtype_shape(original_tensor) return cls( shape=original_tensor.shape, @@ -545,7 +557,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): ) @classmethod - def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor": + def prune_dense_static_sort( + cls, original_tensor: torch.Tensor, algorithm="" + ) -> "SparseSemiStructuredTensor": """ This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata layout and sparse matmul. @@ -576,10 +590,15 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask) ``` """ - (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile( - original_tensor, - algorithm=algorithm, - use_cutlass=False) + ( + packed, + meta, + packed_t, + meta_t, + compressed_swizzled_bitmask, + ) = torch._sparse_semi_structured_tile( + original_tensor, algorithm=algorithm, use_cutlass=False + ) return cls( original_tensor.shape, @@ -592,11 +611,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): ) def _mm( - self, - B: torch.Tensor, - *, - bias: Optional[torch.Tensor] = None, - **kwargs + self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs ) -> torch.Tensor: if isinstance(B, SparseSemiStructuredTensor): raise ValueError(