pytorch/benchmarks/sparse/dlmc/utils.py
commit 39f50f468d by Alexander: matmul performance benchmarks (#51647)
Summary:
A minor PR following up on the previous PR that added the sparse benchmarking utils: https://github.com/pytorch/pytorch/pull/48397

Fixes https://github.com/pytorch/pytorch/issues/44634: performance benchmarks for matrix-matrix and matrix-vector ops (dense-sparse and sparse-sparse, compared against dense-dense).

I ran all benchmarks on a machine with two RTX 8000 GPUs and a 24-core AMD 2970WX, using the `DLMC/magnitude_pruning` dataset at different sparsity levels (the column headers in the tables below).
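
For reference, a minimal sketch of how such a timing can be reproduced with `torch.utils.benchmark`; the shapes and sparsity here are illustrative, not the DLMC ones:

```
import torch
import torch.utils.benchmark as benchmark

n = 2048  # illustrative size; the real benchmark draws matrices from DLMC
dense_a = torch.randn(n, n)
dense_b = torch.randn(n, n)
# zero out ~90% of the entries (sparsity 0.9) and convert to sparse COO
sparse_a = dense_a.masked_fill(torch.rand(n, n) > 0.1, 0).to_sparse()
sparse_b = dense_b.masked_fill(torch.rand(n, n) > 0.1, 0).to_sparse()

t_dense = benchmark.Timer(stmt="a @ b", globals={"a": dense_a, "b": dense_b})
t_sparse = benchmark.Timer(stmt="torch.sparse.mm(a, b)", globals={"a": sparse_a, "b": sparse_b})
print(t_dense.timeit(10))
print(t_sparse.timeit(10))
```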

---
<details><summary>forward tests (expand for details)</summary>

- `sparse@sparse`
```
[------------------------------- cpu:matmul-forward -------------------------------]
                           |   0.5   |   0.7   |   0.8   |   0.9   |  0.95  |   0.98
1 threads: -------------------------------------------------------------------------
      torch:dense@dense    |  108.1  |  100.5  |  101.3  |  108.4  |  98.4  |  187.4
      torch:sparse@sparse  |  659.1  |  368.8  |  156.5  |   53.3  |  26.8  |   14.9
      scipy:sparse@sparse  |  565.1  |  233.9  |  130.2  |   23.1  |  21.6  |   15.2

Times are in milliseconds (ms).

[----------------------------------- cuda:matmul-forward -----------------------------------]
                           |    0.5    |    0.7    |   0.8    |   0.9    |   0.95   |   0.98
1 threads: ----------------------------------------------------------------------------------
      torch:dense@dense    |   2243.5  |   4392.5  |  4419.8  |  2272.3  |  4433.9  |  8920.1
      torch:sparse@sparse  |  21369.2  |  11877.6  |  7339.2  |  1787.2  |  1335.1  |   845.7

Times are in microseconds (us).

```
- `sparse@dense`
```
[------------------------------- cpu:matmul-forward -------------------------------]
                          |   0.5   |   0.7   |   0.8   |   0.9   |   0.95  |   0.98
1 threads: -------------------------------------------------------------------------
      torch:dense@dense   |  105.8  |  103.8  |  103.0  |  104.4  |  104.4  |  197.0
      torch:sparse@dense  |  119.9  |  102.4  |   84.0  |   19.7  |   16.8  |   11.6
      scipy:sparse@dense  |  906.5  |  799.6  |  697.8  |  182.2  |  165.5  |  135.4

Times are in milliseconds (ms).

[------------------------- cuda:matmul-forward --------------------------]
                          |  0.5  |  0.7  |  0.8  |  0.9  |  0.95  |  0.98
1 threads: ---------------------------------------------------------------
      torch:dense@dense   |  2.2  |  4.4  |  4.4  |  2.3  |  4.5   |  2.3
      torch:sparse@dense  |  5.7  |  6.6  |  4.5  |  1.4  |  1.4   |  1.3

Times are in milliseconds (ms).

```
- `sparse@vector`
```
[----------------------------------- cpu:matmul-forward ----------------------------------]
                           |    0.5    |   0.7    |   0.8    |   0.9    |   0.95   |   0.98
1 threads: --------------------------------------------------------------------------------
      torch:dense@vector   |    510.6  |   505.8  |   759.6  |   782.1  |   682.4  |  764.6
      torch:sparse@vector  |  10122.8  |  6241.1  |  7935.6  |  2076.3  |  1049.5  |  826.3
      scipy:sparse@vector  |   1756.7  |  1033.9  |   678.2  |   343.5  |   168.5  |   65.4

Times are in microseconds (us).

[-------------------------------- cuda:matmul-forward --------------------------------]
                           |   0.5    |   0.7    |   0.8   |   0.9   |   0.95  |   0.98
1 threads: ----------------------------------------------------------------------------
      torch:dense@vector   |    36.1  |    21.5  |   21.6  |   21.5  |   21.6  |   21.5
      torch:sparse@vector  |  1099.2  |  1289.4  |  775.7  |  327.1  |  285.4  |  274.0

Times are in microseconds (us).

```
</details>

---
<details><summary>backward tests (expand for details)</summary>

- `sparse@sparse`
```
[--------------------------------- cpu:matmul-backward ---------------------------------]
                           |   0.5    |   0.7    |   0.8    |   0.9    |   0.95  |   0.98
1 threads: ------------------------------------------------------------------------------
      torch:dense@dense    |   246.1  |   315.0  |   306.9  |   168.6  |  290.6  |  146.9
      torch:sparse@sparse  |  6417.5  |  4393.7  |  3012.7  |  1029.4  |  908.0  |  650.7

Times are in microseconds (us).

[----------------------------- cuda:matmul-backward -----------------------------]
                           |   0.5   |   0.7   |   0.8   |  0.9   |  0.95  |  0.98
1 threads: -----------------------------------------------------------------------
      torch:dense@dense    |    6.7  |   13.3  |   13.3  |   6.9  |  13.5  |   6.9
      torch:sparse@sparse  |  143.7  |  143.4  |  119.6  |  29.5  |  29.1  |  10.9

Times are in microseconds (us).

```
- `sparse@dense`
```
[------------------------------- cpu:matmul-backward -------------------------------]
                          |   0.5   |   0.7   |   0.8   |   0.9   |   0.95  |   0.98
1 threads: -------------------------------------------------------------------------
      torch:dense@dense   |  185.9  |  304.8  |  305.8  |  169.9  |  308.7  |  168.4
      torch:sparse@dense  |  407.9  |  345.8  |  274.6  |  114.2  |  163.6  |  230.5

Times are in milliseconds (ms).

[--------------------------- cuda:matmul-backward --------------------------]
                          |  0.5   |  0.7   |  0.8   |  0.9  |  0.95  |  0.98
1 threads: ------------------------------------------------------------------
      torch:dense@dense   |   6.7  |  13.3  |  13.3  |  6.9  |  13.4  |   6.9
      torch:sparse@dense  |  16.7  |  19.0  |  15.1  |  6.3  |   8.2  |  12.7

Times are in milliseconds (ms).

```
</details>
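
Each backward timing boils down to a matmul followed by `.backward()` against a gradient that matches the output layout. A minimal sketch (shapes and sparsity are illustrative), using `sparse_grad_output` from the utils file below:

```
import torch
# sparse_grad_output is defined in benchmarks/sparse/dlmc/utils.py (listed below)

n = 512  # illustrative size
x = torch.randn(n, n).masked_fill(torch.rand(n, n) > 0.1, 0).to_sparse()
y = torch.randn(n, n).masked_fill(torch.rand(n, n) > 0.1, 0).to_sparse()
grad = sparse_grad_output(x, y)  # random gradient with the output's sparsity pattern
x.requires_grad_(True)
y.requires_grad_(True)
torch.sparse.mm(x, y).backward(grad)  # the step a backward test times
```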

Kindly review this PR. cc mruberry, ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51647

Reviewed By: albanD

Differential Revision: D27007809

Pulled By: mruberry

fbshipit-source-id: 8c1922cb3280027ca5e3eef31bfa20500c548cfd

---

import torch
from pathlib import Path
from scipy import sparse
import math


def to_coo_scipy(x):
    # Convert a torch sparse COO tensor into a scipy.sparse.coo_matrix.
    indices_1 = x._indices().numpy()
    values_1 = x._values().numpy()
    return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])),
                             shape=x.shape)


def sparse_grad_output(a, b):
    # Build a random gradient for the output of a @ b: sparse outputs get a
    # gradient with the same sparsity pattern, dense outputs a dense one.
    c = torch.sparse.mm(a, b)
    if c.is_sparse:
        c2 = torch.rand_like(c.to_dense())
        return c2.sparse_mask(c.coalesce())
    else:
        return torch.rand_like(c)


def read_matrix_params(path):
    # Read the "nrows, ncols, nnz" header line of a .smtx file.
    with open(path, 'r') as file:
        line = file.readline()
        nrows, ncols, nnz = map(lambda el: int(el), line.split(', '))
        return (nrows, ncols), nnz
def csr_to_coo(indices, indptr, shape):
    # Expand the CSR row pointers into explicit row indices (COO format).
    n_rows, n_cols = shape
    cols = indices
    rows = [0] * len(cols)
    for i in range(n_rows):
        for j in range(indptr[i], indptr[i + 1]):
            rows[j] = i
    return torch.tensor([rows, cols], dtype=torch.long)


def load_sparse_matrix(path, device):
    # Load the sparsity pattern of a .smtx file and fill it with random values.
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
        index_pointers = map(lambda el: int(el), file.readline().split())
        indices = map(lambda el: int(el), file.readline().split())
        index_pointers = list(index_pointers)
        indices = list(indices)
        data = torch.randn(nnz, dtype=torch.double)
        shape = (nrows, ncols)
        return torch.sparse_coo_tensor(csr_to_coo(indices, index_pointers, shape), data, shape, device=device)


def gen_vector(path, device):
    # Only the row count is needed; the CSR payload lines are read and discarded.
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
        index_pointers = map(lambda el: int(el), file.readline().split())
        indices = map(lambda el: int(el), file.readline().split())
        return torch.randn(nrows, dtype=torch.double, device=device)


def gen_matrix(path, device):
    # Only the shape is needed; the CSR payload lines are read and discarded.
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
        index_pointers = map(lambda el: int(el), file.readline().split())
        indices = map(lambda el: int(el), file.readline().split())
        return torch.randn(nrows, ncols, dtype=torch.double, device=device)
def load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit=math.inf):
    """load_spmv_dataset loads a DLMC dataset for a sparse matrix-vector multiplication (SPMV) performance test.

    Args:
        dataset_path:
            path of the dataset from the DLMC collection.
        hidden_size:
            dimension that selected matrices must match (columns of the
            sparse operand, rows of the vector operand).
        sparsity:
            sparsity level of the matrices to load (a subfolder of the dataset).
        device:
            whether to place the tensors on a GPU or CPU.
        n_limit:
            maximum number of files to read from the dataset.
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(f.as_posix())
        # Pair up files whose shapes are compatible for x @ y.
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_vector(fy, device)
        yield (x, y)
def load_spmm_dataset(dataset_path, hidden_size, sparsity, spmm_type, device, n_limit=math.inf):
    """load_spmm_dataset loads a DLMC dataset for a sparse matrix-matrix multiplication (SPMM) performance test.

    Args:
        dataset_path:
            path of the dataset from the DLMC collection.
        hidden_size:
            dimension that selected matrices must match (columns of the
            left operand, rows of the right operand).
        sparsity:
            sparsity level of the matrices to load (a subfolder of the dataset).
        spmm_type:
            `sparse@sparse` or `sparse@dense`; selects whether the right
            operand is sparse or dense.
        device:
            whether to place the tensors on a GPU or CPU.
        n_limit:
            maximum number of files to read from the dataset.
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(f.as_posix())
        # Pair up files whose shapes are compatible for x @ y.
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_matrix(fy, device) if spmm_type == 'sparse@dense' else load_sparse_matrix(fy, device)
        yield (x, y)
def load_dlmc_dataset(dataset_path, operation, hidden_size, sparsity, device, requires_grad, n_limit=math.inf):
    """load_dlmc_dataset loads a DLMC dataset for a matmul performance test.

    Args:
        dataset_path:
            path of the dataset from the DLMC collection.
        operation:
            one of `sparse@sparse` | `sparse@dense` | `sparse@vector`.
        hidden_size:
            dimension that selected matrices must match (columns of the
            left operand, rows of the right operand).
        sparsity:
            sparsity level of the matrices to load (a subfolder of the dataset).
        device:
            whether to place the tensors on a GPU or CPU.
        requires_grad:
            if True, load the dataset for a backward test: operands require
            grad and a matching grad output is generated.
        n_limit:
            maximum number of files to read from the dataset.
    """
    if operation == 'sparse@sparse' or operation == "sparse@dense":
        collection = load_spmm_dataset(dataset_path, hidden_size, sparsity, operation, device, n_limit)
    elif operation == 'sparse@vector':
        collection = load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit)
    scipy_vars = {}
    backward_vars = {}
    for x, y in collection:
        if device == 'cpu':
            # scipy baselines only exist on CPU.
            scipy_vars = {
                "sx": to_coo_scipy(x) if x.is_sparse else x.numpy(),
                "sy": to_coo_scipy(y) if y.is_sparse else y.numpy(),
            }
        if not requires_grad:
            dx = x.to_dense() if x.is_sparse else x
            dy = y.to_dense() if y.is_sparse else y
        else:
            c = sparse_grad_output(x, y)
            backward_vars = {
                "sparse_grad_output": c,
                "grad_output": c.to_dense() if c.is_sparse else c,
            }
            x.requires_grad_(True)
            y.requires_grad_(True)
            # Dense copies are detached so each variant tracks its own gradient.
            dx = x.to_dense().detach() if x.is_sparse else x.clone().detach()
            dy = y.to_dense().detach() if y.is_sparse else y.clone().detach()
            dx.requires_grad_(True)
            dy.requires_grad_(True)
        yield {
            "x": x,
            "y": y,
            "dx": dx,
            "dy": dy,
            **scipy_vars,
            **backward_vars
        }
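
For context, a sketch of how this generator might be consumed in a benchmark loop; the dataset path, hidden size, and sparsity values here are illustrative assumptions, not values taken from the PR:

```
import torch
from utils import load_dlmc_dataset  # this module

for tensors in load_dlmc_dataset(
        dataset_path="dlmc/magnitude_pruning",  # hypothetical local path
        operation="sparse@sparse",
        hidden_size=2048,
        sparsity="0.9",
        device="cpu",
        requires_grad=False,
):
    result = torch.sparse.mm(tensors["x"], tensors["y"])  # sparse kernel under test
    reference = tensors["dx"] @ tensors["dy"]             # dense baseline
    scipy_result = tensors["sx"] @ tensors["sy"]          # scipy baseline (CPU only)
```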