Sparse-sparse matrix multiplication (CPU/CUDA) (#39526)

Summary:
This PR implements matrix multiplication support for 2-d sparse tensors using the COO sparse format.

The current implementation of `torch.sparse.mm` supports the sparse-dense configuration,
`torch.sparse.mm(sparse_matrix1, sparse_matrix2.to_dense())`, but this can consume a lot of memory when `sparse_matrix2` is large.
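
To make the memory concern concrete, here is a rough back-of-the-envelope sketch. It assumes float32 values and int64 indices (two per stored entry, as in a 2-d COO layout); the helper names and the 10000 x 10000 shape are hypothetical and not part of this PR.

```python
def dense_bytes(shape, value_bytes=4):
    """Bytes needed to store a dense matrix of the given shape."""
    rows, cols = shape
    return rows * cols * value_bytes


def coo_bytes(shape, density, value_bytes=4, index_bytes=8):
    """Approximate bytes for a 2-d COO matrix.

    Each stored entry is assumed to cost two indices plus one value.
    """
    rows, cols = shape
    nnz = int(rows * cols * density)
    return nnz * (2 * index_bytes + value_bytes)


shape = (10000, 10000)  # hypothetical size of sparse_matrix2
for density in (0.01, 0.05, 0.25):
    dense_mb = dense_bytes(shape) / 1e6
    coo_mb = coo_bytes(shape, density) / 1e6
    print(f"density={density}: dense={dense_mb:.0f} MB, coo={coo_mb:.0f} MB")
```

Under these assumptions the dense copy costs about 400 MB regardless of density, while the COO form costs about 20 MB at 1% density. The advantage shrinks as density grows, so the saving matters most for very sparse operands.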

This implementation extends the `torch.sparse.mm` function to support `torch.sparse.mm(sparse_matrix1, sparse_matrix2)`.
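
A minimal usage sketch of the sparse-sparse call described above, assuming a PyTorch build that includes this change. The indices and values are made up for illustration, and the dense product is computed only to check the result.

```python
import torch

# Two small 2-d COO tensors; indices and values are illustrative only.
a = torch.sparse_coo_tensor(
    indices=[[0, 0, 1], [0, 2, 1]],
    values=[1.0, 2.0, 3.0],
    size=(2, 3),
)
b = torch.sparse_coo_tensor(
    indices=[[0, 1, 2], [1, 0, 1]],
    values=[4.0, 5.0, 6.0],
    size=(3, 2),
)

# Sparse @ sparse: neither operand is converted to dense by the caller,
# and the result is a sparse tensor as well.
c = torch.sparse.mm(a, b)

# Dense reference, used here only for verification.
expected = torch.mm(a.to_dense(), b.to_dense())
print(c.is_sparse, torch.allclose(c.to_dense(), expected))
```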

Resolves [#20988](https://github.com/pytorch/pytorch/issues/20988) for CPU/CUDA.

- [x] sparse matmul
  - [x] CPU/CUDA C++ implementation
  - [x] unittests
  - [x] update torch.sparse.mm documentation
  - [x] autograd support

The CPU sparse-sparse matmul was implemented using "Sparse Matrix Multiplication Package (SMMP)" as a reference. The GPU sparse-sparse matmul is based on cuSPARSE, with separate code paths for `CUSPARSE_VERSION >= 11` and for older cuSPARSE versions. Both the CPU and CUDA implementations rely on a sparse-sparse matmul algorithm that uses the CSR index format, as it is one of the fastest algorithms.
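
The CSR-based approach can be illustrated with a small pure-Python sketch. This is not the PR's C++ or cuSPARSE code: it is a simplified row-by-row accumulation over CSR arrays, and the function names `coo_to_csr` and `csr_matmul` are hypothetical.

```python
def coo_to_csr(n_rows, rows, cols, vals):
    """Convert COO triplets to CSR arrays (row_ptr, col_idx, data)."""
    order = sorted(range(len(vals)), key=lambda t: (rows[t], cols[t]))
    row_ptr = [0] * (n_rows + 1)
    for r in rows:
        row_ptr[r + 1] += 1          # count entries per row
    for i in range(n_rows):
        row_ptr[i + 1] += row_ptr[i]  # prefix sum gives row offsets
    col_idx = [cols[t] for t in order]
    data = [vals[t] for t in order]
    return row_ptr, col_idx, data


def csr_matmul(a, b, n_rows):
    """Multiply two CSR matrices row by row and return COO triplets."""
    a_ptr, a_col, a_val = a
    b_ptr, b_col, b_val = b
    out_rows, out_cols, out_vals = [], [], []
    for i in range(n_rows):
        acc = {}  # column -> accumulated value for output row i
        for p in range(a_ptr[i], a_ptr[i + 1]):
            k, a_ik = a_col[p], a_val[p]
            # Row k of B contributes a_ik * B[k, :] to row i of the result.
            for q in range(b_ptr[k], b_ptr[k + 1]):
                j = b_col[q]
                acc[j] = acc.get(j, 0.0) + a_ik * b_val[q]
        for j in sorted(acc):
            out_rows.append(i)
            out_cols.append(j)
            out_vals.append(acc[j])
    return out_rows, out_cols, out_vals


a_csr = coo_to_csr(2, [0, 0, 1], [0, 2, 1], [1.0, 2.0, 3.0])
b_csr = coo_to_csr(3, [0, 1, 2], [1, 0, 1], [4.0, 5.0, 6.0])
result = csr_matmul(a_csr, b_csr, n_rows=2)
print(result)
```

CSR suits this access pattern because each row of the left operand selects whole rows of the right operand, and those rows are stored contiguously.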

Here are the latest benchmark results (script is here) for torch.sparse.mm (CUDA), torch.sparse.mm (CPU), and scipy; values are float32 scalars:

size | density | sparse.mm(CUDA) | sparse.mm(CPU) | scipy_coo_matmul
-- | -- | -- | -- | --
(32, 10000) | 0.01 | 822.7 | 79.4 | 704.1
(32, 10000) | 0.05 | 1741.1 | 402.6 | 1155.3
(32, 10000) | 0.1 | 2956.8 | 840.8 | 1885.4
(32, 10000) | 0.25 | 6417.7 | 2832.3 | 4665.2
(512, 10000) | 0.01 | 1010.2 | 3941.3 | 26937.7
(512, 10000) | 0.05 | 2216.2 | 26903.8 | 57343.7
(512, 10000) | 0.1 | 4868.4 | 87773.7 | 117477.0
(512, 10000) | 0.25 | 16639.3 | 608105.0 | 624290.4
(1024, 10000) | 0.01 | 1224.8 | 13088.1 | 110379.2
(1024, 10000) | 0.05 | 3897.5 | 94783.9 | 236541.8
(1024, 10000) | 0.1 | 10559.1 | 405312.5 | 525483.4
(1024, 10000) | 0.25 | 57456.3 | 2424337.5 | 2729318.7

A new backward algorithm was implemented using only `sparse @ sparse` and `sparse_mask` operations. Here is some benchmarking:

```
[------------------------- sparse.mm-backward -------------------------]
                            |   sparse.backward   |  dense.backward
 -----------------------------------------------------------------------
      (32, 10000) | 0.01    |            13.5          |         2.4
      (32, 10000) | 0.05    |            52.3          |         2.4
      (512, 10000) | 0.01   |          1016.8          |       491.5
      (512, 10000) | 0.05   |          1604.3          |       492.3
      (1024, 10000) | 0.01  |          2384.1          |      1963.7
      (1024, 10000) | 0.05  |          3965.8          |      1951.9
```
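
The idea of a backward pass built only from sparse products and masking can be sketched as follows. This assumes the standard matmul gradient identities (for `C = A @ B` with upstream gradient `G`, `dA = G @ B^T` and `dB = A^T @ G`), each restricted to the sparsity pattern of the corresponding input; the PR text does not spell out the formulas, and the dict-based helpers below are hypothetical stand-ins for sparse tensors.

```python
def spmm(a, b):
    """Product of two sparse matrices stored as {(row, col): value} dicts."""
    b_rows = {}
    for (k, j), v in b.items():
        b_rows.setdefault(k, []).append((j, v))
    out = {}
    for (i, k), a_ik in a.items():
        for j, b_kj in b_rows.get(k, []):
            out[(i, j)] = out.get((i, j), 0.0) + a_ik * b_kj
    return out


def transpose(m):
    """Swap row and column of every stored entry."""
    return {(j, i): v for (i, j), v in m.items()}


def sparse_mask(m, pattern):
    """Keep only entries of m at positions stored in pattern (missing -> 0)."""
    return {pos: m.get(pos, 0.0) for pos in pattern}


def spmm_backward(grad_out, a, b):
    """Gradients of C = A @ B, each masked to its input's sparsity pattern."""
    grad_a = sparse_mask(spmm(grad_out, transpose(b)), a)
    grad_b = sparse_mask(spmm(transpose(a), grad_out), b)
    return grad_a, grad_b


a = {(0, 0): 1.0, (0, 2): 2.0, (1, 1): 3.0}
b = {(0, 1): 4.0, (1, 0): 5.0, (2, 1): 6.0}
# Upstream gradient of sum(C): ones at the stored positions of C.
grad_out = {pos: 1.0 for pos in spmm(a, b)}
grad_a, grad_b = spmm_backward(grad_out, a, b)
print(grad_a, grad_b)
```

Masking keeps each gradient at the same number of stored entries as its input, which is where the memory saving comes from, at the cost of extra sparse products.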

I also added new benchmark tests that use a real dataset from recent studies [1, 2] at different sparsity levels.

```
[---------------------------------- matmul ---------------------------------]
                        |   0.5   |  0.7   |  0.8   |  0.9   |  0.95  |  0.98
1 threads: ------------------------------------------------------------------
  (cpu)   torch         |    5.4  |   5.4  |   5.2  |   5.3  |   5.3  |   5.4
          torch.sparse  |  122.2  |  51.9  |  27.5  |  11.4  |   4.9  |   1.8
          scipy         |  150.1  |  87.4  |  69.2  |  56.8  |  38.4  |  17.1
  (cuda)  torch         |    1.3  |   1.1  |   1.1  |   1.1  |   1.1  |   1.1
          torch.sparse  |   20.0  |   8.4  |   5.1  |   2.5  |   1.5  |   1.1

[----------------------------------- backward -----------------------------------]
                        |   0.5   |   0.7   |   0.8   |   0.9   |   0.95  |   0.98
1 threads: -----------------------------------------------------------------------
  (cpu)   torch         |   17.7  |   17.9  |   17.7  |   17.7  |   17.6  |   17.9
          torch.sparse  |  672.9  |  432.6  |  327.5  |  230.8  |  176.7  |  116.7
  (cuda)  torch         |    3.8  |    3.6  |    3.5  |    3.5  |    3.6  |    3.5
          torch.sparse  |   68.8  |   46.2  |   35.6  |   24.2  |   17.8  |   11.9

Times are in milliseconds (ms).
```

In summary, the new `sparse @ sparse` backward algorithm is preferable because its goal is to save memory rather than to maximize speed. It also performed better than the other options tested earlier.

## **References**

1. Trevor Gale, Matei Zaharia, Cliff Young, Erich Elsen. **Sparse GPU Kernels for Deep Learning.**  Proceedings of the International Conference for High Performance Computing, 2020. [https://github.com/google-research/google-research/tree/master/sgk](https://github.com/google-research/google-research/tree/master/sgk)
2. Trevor Gale, Erich Elsen, Sara Hooker. **The State of Sparsity in Deep Neural Networks.** [https://github.com/google-research/google-research/tree/master/state_of_sparsity](https://github.com/google-research/google-research/tree/master/state_of_sparsity)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39526

Reviewed By: mruberry

Differential Revision: D25661239

Pulled By: ngimel

fbshipit-source-id: b515ecd66d25f347d637e159d51aa45fb43b6938
Alexander
2020-12-21 11:51:52 -08:00
committed by Facebook GitHub Bot
parent 3779bdec56
commit 44ce0b8883
19 changed files with 1849 additions and 145 deletions


@ -0,0 +1,198 @@
# Sparse benchmarks

# These benchmarks are for the sparse matrix functionality.
# They exist for comparing the performance of sparse matrix routines
# `torch.sparse.mm(sparse, sparse)` with different backends (CPU/CUDA)
# and with other frameworks such as scipy.

import sys
from scipy import sparse
import numpy as np
from pathlib import Path
import pandas as pd
import argparse
import pickle  # needed for pickle.dumps / pickle.loads in the main block
import torch
import torch.utils.benchmark as benchmark_utils


def read_matrix_params(path):
    # The first line of the file holds "nrows, ncols, nnz".
    with open(path) as f:
        nrows, ncols, nnz = (int(el) for el in f.readline().split(', '))
    return (nrows, ncols), nnz


def load_matrix(path):
    # Only the sparsity pattern (CSR index pointers and column indices)
    # is read from the file; the values are random.
    with open(path) as f:
        nrows, ncols, nnz = (int(el) for el in f.readline().split(', '))
        index_pointers = [int(el) for el in f.readline().split()]
        indices = [int(el) for el in f.readline().split()]

    data = np.random.rand(nnz)
    coo = sparse.csr_matrix(
        (data, np.array(indices), np.array(index_pointers)),
        shape=(nrows, ncols)).tocoo()
    return torch.sparse_coo_tensor([coo.row, coo.col], coo.data, coo.shape)


def scipy_coo_matmul(mat1, mat2):
    result = mat1.dot(mat2).tocoo()
    return torch.sparse_coo_tensor([result.row, result.col], result.data,
                                   result.shape)


def to_coo_scipy(x):
    indices_1 = x._indices().numpy()
    values_1 = x._values().numpy()
    return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])),
                             shape=x.shape)


def torch_backward(a_dense, b_dense):
    a_dense.requires_grad = True
    b_dense.requires_grad = True
    r1 = a_dense.matmul(b_dense)
    f1 = torch.sum(r1)
    f1.backward()


def sparse_torch_backward(a, b):
    a.requires_grad = True
    b.requires_grad = True
    r2 = torch.sparse.mm(a, b)
    f2 = torch.sparse.sum(r2)
    f2.backward()


def load_dataset(dataset_path, hidden_size, sparsity, n_limit=20):
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    xs = []
    ys = []
    print(dataset_path, hidden_size, sparsity)
    index = 0
    for elem in files:
        if index == n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(elem.as_posix())
        if size[1] == hidden_size:
            xs.append(load_matrix(elem.as_posix()))
        if size[0] == hidden_size:
            ys.append(load_matrix(elem.as_posix()))
        index += 1
    print()
    return zip(xs, ys)


if __name__ == '__main__':
    path = Path()
    parser = argparse.ArgumentParser(description='Sparse Matmul Bench')
    parser.add_argument('--path', type=str, help='dataset path')
    parser.add_argument('--dataset',
                        type=str,
                        help='dataset name',
                        default='random_pruning')
    parser.add_argument('--operation',
                        type=str,
                        help='matmul or backward',
                        default='matmul')
    parser.add_argument('--output',
                        type=str,
                        help='dataframe output path',
                        default='/tmp/matmul_bench.pkl')
    args = parser.parse_args()
    print('path =', args.path)
    print('dataset =', args.dataset)
    print('operation =', args.operation)
    print('output =', args.output)

    dataset_path = args.path
    dataset_name = args.dataset
    dataset_path = f"{dataset_path}/{dataset_name}"
    df_output_path = args.output
    tasks = []
    if args.operation == 'matmul':
        tasks = [
            ("matmul", "cpu", "torch", "torch.mm(dense_x, dense_y)"),
            ("matmul", "cpu", "torch.sparse", "torch.sparse.mm(tx, ty)"),
            ("matmul", "cpu", "scipy",
             "scipy_coo_matmul(scipy_varx, scipy_vary)"),
            ("matmul", "cuda", "torch",
             "torch.mm(dense_cuda_x, dense_cuda_y)"),
            ("matmul", "cuda", "torch.sparse",
             "torch.sparse.mm(tx_cuda, ty_cuda)"),
        ]
    else:
        tasks = [
            ("backward", "cpu", "torch", "torch_backward(dense_x, dense_y)"),
            ("backward", "cpu", "torch.sparse",
             "sparse_torch_backward(tx, ty)"),
            ("backward", "cuda", "torch",
             "torch_backward(dense_cuda_x, dense_cuda_y)"),
            ("backward", "cuda", "torch.sparse",
             "sparse_torch_backward(tx_cuda, ty_cuda)"),
        ]
    serialized_results = []
    repeats = 2
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "scipy_coo_matmul": scipy_coo_matmul,
                "torch_backward": torch_backward,
                "sparse_torch_backward": sparse_torch_backward,
                "scipy_varx": to_coo_scipy(x),
                "scipy_vary": to_coo_scipy(y),
                "tx": x,
                "ty": y,
                "tx_cuda": x.cuda(),
                "ty_cuda": y.cuda(),
                "dense_cuda_x": x.to_dense().cuda(),
                "dense_cuda_y": y.to_dense().cuda(),
                "dense_x": x.to_dense(),
                "dense_y": y.to_dense(),
            },
            label=label,
            sub_label=sub_label,
            description=f"{sparsity}",
            env=device,
            # num_threads=num_threads,
        ) for hidden_size in [512]
        for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
        for label, device, sub_label, stmt in tasks
        for num_threads in [1, 4, 8, 16]
        for x, y in load_dataset(dataset_path, hidden_size, sparsity)
    ]
    measurements = []

    for i, timer in enumerate(timers * repeats):
        m = timer.blocked_autorange(min_run_time=0.05)
        serialized_results.append(pickle.dumps(m))
        m.metadata = {
            "device": 'cuda' if m.task_spec.env.find("cuda") >= 0 else 'cpu'
        }
        measurements.append(m)
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

    comparison = benchmark_utils.Compare(
        [pickle.loads(i) for i in serialized_results])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()

    table = [(m.task_spec.sub_label, m.task_spec.description,
              m.metadata["device"], m.mean) for m in measurements]
    df = pd.DataFrame(table, columns=['method', 'sparsity', 'device', 'time'])
    df.to_pickle(df_output_path)

benchmarks/sparse/test.sh Normal file

@ -0,0 +1,14 @@
#!/bin/bash
DATASET_ROOT_DIR=$HOME/datasets/
# wget https://storage.googleapis.com/sgk-sc2020/dlmc.tar.gz -P $DATASET_ROOT_DIR
# tar -xvf $DATASET_ROOT_DIR/dlmc.tar.gz
echo "!! SPARSE SPMS TIME BENCHMARK!! "
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation matmul --output /tmp/matmul_bench.pkl
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation backward --output /tmp/backward_bench.pkl
python plot_results.py -i /tmp/matmul_bench.pkl
python plot_results.py -i /tmp/backward_bench.pkl