matmul performance benchmarks (#51647)

Summary:
Minor PR following up the previous PR about sparse benchmarking utils https://github.com/pytorch/pytorch/pull/48397

Fixes https://github.com/pytorch/pytorch/issues/44634: Performance benchmarks for matrix-matrix and matrix-vector ops (dense-sparse, sparse-sparse, and a comparison to dense-dense)

I ran all benchmarks on a 2xRTX8000 machine with a 24-core AMD 2970WX, using the `DLMC/magnitude_pruning` dataset at different sparsity levels.

 ---
<details><summary> forward tests (expand for details).
</summary>

- `sparse@sparse`
```
[------------------------------- cpu:matmul-forward -------------------------------]
                           |   0.5   |   0.7   |   0.8   |   0.9   |  0.95  |   0.98
1 threads: -------------------------------------------------------------------------
      torch:dense@dense    |  108.1  |  100.5  |  101.3  |  108.4  |  98.4  |  187.4
      torch:sparse@sparse  |  659.1  |  368.8  |  156.5  |   53.3  |  26.8  |   14.9
      scipy:sparse@sparse  |  565.1  |  233.9  |  130.2  |   23.1  |  21.6  |   15.2

Times are in milliseconds (ms).

[----------------------------------- cuda:matmul-forward -----------------------------------]
                           |    0.5    |    0.7    |   0.8    |   0.9    |   0.95   |   0.98
1 threads: ----------------------------------------------------------------------------------
      torch:dense@dense    |   2243.5  |   4392.5  |  4419.8  |  2272.3  |  4433.9  |  8920.1
      torch:sparse@sparse  |  21369.2  |  11877.6  |  7339.2  |  1787.2  |  1335.1  |   845.7

Times are in microseconds (us).

```
- `sparse@dense`
```
[------------------------------- cpu:matmul-forward -------------------------------]
                          |   0.5   |   0.7   |   0.8   |   0.9   |   0.95  |   0.98
1 threads: -------------------------------------------------------------------------
      torch:dense@dense   |  105.8  |  103.8  |  103.0  |  104.4  |  104.4  |  197.0
      torch:sparse@dense  |  119.9  |  102.4  |   84.0  |   19.7  |   16.8  |   11.6
      scipy:sparse@dense  |  906.5  |  799.6  |  697.8  |  182.2  |  165.5  |  135.4

Times are in milliseconds (ms).

[------------------------- cuda:matmul-forward --------------------------]
                          |  0.5  |  0.7  |  0.8  |  0.9  |  0.95  |  0.98
1 threads: ---------------------------------------------------------------
      torch:dense@dense   |  2.2  |  4.4  |  4.4  |  2.3  |  4.5   |  2.3
      torch:sparse@dense  |  5.7  |  6.6  |  4.5  |  1.4  |  1.4   |  1.3

Times are in milliseconds (ms).

```
- `sparse@vector`
```
[----------------------------------- cpu:matmul-forward ----------------------------------]
                           |    0.5    |   0.7    |   0.8    |   0.9    |   0.95   |   0.98
1 threads: --------------------------------------------------------------------------------
      torch:dense@vector   |    510.6  |   505.8  |   759.6  |   782.1  |   682.4  |  764.6
      torch:sparse@vector  |  10122.8  |  6241.1  |  7935.6  |  2076.3  |  1049.5  |  826.3
      scipy:sparse@vector  |   1756.7  |  1033.9  |   678.2  |   343.5  |   168.5  |   65.4

Times are in microseconds (us).

[-------------------------------- cuda:matmul-forward --------------------------------]
                           |   0.5    |   0.7    |   0.8   |   0.9   |   0.95  |   0.98
1 threads: ----------------------------------------------------------------------------
      torch:dense@vector   |    36.1  |    21.5  |   21.6  |   21.5  |   21.6  |   21.5
      torch:sparse@vector  |  1099.2  |  1289.4  |  775.7  |  327.1  |  285.4  |  274.0

Times are in microseconds (us).

```
</details>
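
As a reading aid for the forward `sparse@sparse` CPU table above, the sketch below computes the dense-to-sparse time ratio per sparsity level and reports the first level at which `torch:sparse@sparse` is faster than `torch:dense@dense`. The timings are copied from that table; the variable and function names are illustrative only and are not part of the benchmark scripts.

```python
# Timings (ms) copied from the cpu:matmul-forward `sparse@sparse` table.
sparsities = [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
dense_ms = [108.1, 100.5, 101.3, 108.4, 98.4, 187.4]
sparse_ms = [659.1, 368.8, 156.5, 53.3, 26.8, 14.9]


def speedups(dense, sparse):
    """Return dense/sparse time ratios; values above 1 mean sparse is faster."""
    return [d / s for d, s in zip(dense, sparse)]


def first_crossover(levels, ratios):
    """Return the first sparsity level whose ratio exceeds 1, or None."""
    for level, ratio in zip(levels, ratios):
        if ratio > 1.0:
            return level
    return None


ratios = speedups(dense_ms, sparse_ms)
for level, ratio in zip(sparsities, ratios):
    print(f"sparsity {level}: dense/sparse = {ratio:.2f}")
print("sparse becomes faster at sparsity", first_crossover(sparsities, ratios))
```

The same two helpers can be pointed at any other pair of rows from the tables, provided both rows use the same time unit.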

 ---
<details><summary> backward tests (expand for details).
</summary>

- `sparse@sparse`
```
[--------------------------------- cpu:matmul-backward ---------------------------------]
                           |   0.5    |   0.7    |   0.8    |   0.9    |   0.95  |   0.98
1 threads: ------------------------------------------------------------------------------
      torch:dense@dense    |   246.1  |   315.0  |   306.9  |   168.6  |  290.6  |  146.9
      torch:sparse@sparse  |  6417.5  |  4393.7  |  3012.7  |  1029.4  |  908.0  |  650.7

Times are in microseconds (us).

[----------------------------- cuda:matmul-backward -----------------------------]
                           |   0.5   |   0.7   |   0.8   |  0.9   |  0.95  |  0.98
1 threads: -----------------------------------------------------------------------
      torch:dense@dense    |    6.7  |   13.3  |   13.3  |   6.9  |  13.5  |   6.9
      torch:sparse@sparse  |  143.7  |  143.4  |  119.6  |  29.5  |  29.1  |  10.9

Times are in microseconds (us).

```
- `sparse@dense`
```
[------------------------------ cpu:matmul-backward -------------------------------]
                          |   0.5   |   0.7   |   0.8   |   0.9   |   0.95  |   0.98
1 threads: -------------------------------------------------------------------------
      torch:dense@dense   |  185.9  |  304.8  |  305.8  |  169.9  |  308.7  |  168.4
      torch:sparse@dense  |  407.9  |  345.8  |  274.6  |  114.2  |  163.6  |  230.5

Times are in milliseconds (ms).

[--------------------------- cuda:matmul-backward --------------------------]
                          |  0.5   |  0.7   |  0.8   |  0.9  |  0.95  |  0.98
1 threads: ------------------------------------------------------------------
      torch:dense@dense   |   6.7  |  13.3  |  13.3  |  6.9  |  13.4  |   6.9
      torch:sparse@dense  |  16.7  |  19.0  |  15.1  |  6.3  |   8.2  |  12.7

Times are in milliseconds (ms).

```
</details>

Kindly review this PR. cc mruberry, ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51647

Reviewed By: albanD

Differential Revision: D27007809

Pulled By: mruberry

fbshipit-source-id: 8c1922cb3280027ca5e3eef31bfa20500c548cfd
Alexander
2021-03-14 00:15:12 -08:00
committed by Facebook GitHub Bot
parent 142c6b0e55
commit 39f50f468d
7 changed files with 370 additions and 212 deletions


@ -0,0 +1,15 @@
# Sparse benchmarks
This set of benchmarks exercises the sparse matrix functionality using a popular collection of real datasets called the Deep Learning Matrix Collection (DLMC), which was used in recent studies [1, 2].
Performance benchmark scripts for matrix-matrix and matrix-vector ops (dense-sparse and sparse-sparse, each compared to dense-dense) are implemented here.
- `matmul_bench.py` with `--operation sparse@sparse|sparse@dense` is the sparse matrix-matrix multiplication (SPMM) performance test. It runs in forward mode by default or in backward mode with `--backward_test`, and on CPU by default or on CUDA with `--with_cuda`, using different datasets from the DLMC collection. For more details, see the `test.sh` file.
- `matmul_bench.py` with `--operation sparse@vector` is the sparse matrix-vector multiplication (SPMV) performance test.
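
The DLMC `.smtx` files that these scripts read start with a `nrows, ncols, nnz` header line, followed by a line of row pointers and a line of column indices (this layout is inferred from how the loader in `utils.py` parses the files). A minimal pure-Python sketch of expanding that CSR-style layout into COO coordinates is shown below; the sample text and the `parse_smtx` and `csr_rows` names are hypothetical and not part of the benchmark code.

```python
def csr_rows(indptr):
    """Expand CSR row pointers into one row index per stored element."""
    rows = []
    for row in range(len(indptr) - 1):
        rows.extend([row] * (indptr[row + 1] - indptr[row]))
    return rows


def parse_smtx(text):
    """Parse a three-line .smtx-style string into (shape, rows, cols)."""
    header, indptr_line, indices_line = text.strip().splitlines()
    nrows, ncols, nnz = (int(el) for el in header.split(", "))
    indptr = [int(el) for el in indptr_line.split()]
    cols = [int(el) for el in indices_line.split()]
    if len(indptr) != nrows + 1 or len(cols) != nnz:
        raise ValueError("header does not match the index arrays")
    return (nrows, ncols), csr_rows(indptr), cols


# Made-up 3x4 matrix with 4 stored elements, for illustration only.
sample = "3, 4, 4\n0 2 2 4\n0 3 1 2\n"
shape, rows, cols = parse_smtx(sample)
print(shape, list(zip(rows, cols)))
```

The files carry only the sparsity pattern; the loader in `utils.py` fills the stored elements with random values.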
References:
1. Trevor Gale, Matei Zaharia, Cliff Young, Erich Elsen. Sparse GPU Kernels for Deep Learning. Proceedings of the International Conference for High Performance Computing, 2020. https://github.com/google-research/google-research/tree/master/sgk
2. Trevor Gale, Erich Elsen, Sara Hooker. The State of Sparsity in Deep Neural Networks. https://github.com/google-research/google-research/tree/master/state_of_sparsity


@ -0,0 +1,3 @@
if __name__ == "__main__":
    pass


@ -0,0 +1,126 @@
# Sparse benchmarks

# This benchmark is for sparse matmul performance test.
# They exist for comparing the performance of sparse matrix routines
# `sparse @ vector`, `sparse @ sparse` and `sparse @ dense` with different backends (CPU/CUDA)
# and with other frameworks such as scipy.

import sys
import argparse
import torch
import torch.utils.benchmark as benchmark_utils
from .utils import load_dlmc_dataset
from scipy.sparse import isspmatrix
import os


def scipy_matmul(mat1, mat2):
    if isspmatrix(mat1) and isspmatrix(mat2):
        return mat1.dot(mat2).tocoo()
    return mat1.dot(mat2)


def matmul_backward(a_dense, b_dense, grad_output):
    r1 = a_dense.matmul(b_dense)
    r1.backward(grad_output)


def sparse_matmul_backward(a, b, grad_output):
    c = torch.sparse.mm(a, b)
    c.backward(grad_output)


OPS_MAP = {
    "sparse@sparse": "torch.sparse.mm",
    "sparse@dense": "torch.matmul",
    "sparse@vector": "torch.matmul",
}


# also get the arguments as input from the user using `argparse`
def parse_args():
    parser = argparse.ArgumentParser(description='matmul benchmark')
    parser.add_argument('--path', type=str, help='DLMC dataset path')
    parser.add_argument('--dataset', type=str, default='magnitude_pruning')
    parser.add_argument('--hidden_size', default=2048, type=int)
    parser.add_argument('--backward_test', action="store_true")
    parser.add_argument('--operation', type=str, help="|".join(OPS_MAP.keys()), default=next(iter(OPS_MAP)))
    parser.add_argument('--with_cuda', action='store_true')
    parser.add_argument('--timer_min_run_time', default=1, type=float)
    return parser


def get_tasks(op, backward_test, device):
    def filter_ops(operation):
        if backward_test:
            test_name = device + ":matmul-backward"
            return [
                (test_name, device, "torch:" + operation.replace("sparse", "dense"),
                 "matmul_backward(dx, dy, grad_output)"),
                (test_name, device, "torch:" + operation, "sparse_matmul_backward(x, y, sparse_grad_output)")
            ]
        else:
            test_name = device + ":matmul-forward"
            return list(filter(None, [
                (test_name, device, "torch:" + operation.replace("sparse", "dense"),
                 "{}(dx, dy)".format(OPS_MAP[operation])),
                (test_name, device, "torch:" + operation, "{}(x, y)".format(OPS_MAP[operation])),
                (test_name, device, "scipy:" + operation, "scipy_matmul(sx, sy)") if device == "cpu" else None
            ]))

    all_operations = {
        "sparse@sparse": filter_ops("sparse@sparse"),
        "sparse@dense": filter_ops("sparse@dense"),
        "sparse@vector": filter_ops("sparse@vector"),
    }
    return all_operations[op]


if __name__ == '__main__':
    parser = parse_args()
    args = parser.parse_args()

    if args.with_cuda and not torch.cuda.is_available():
        raise RuntimeError("No CUDA available")

    dataset_path = args.path
    dataset_name = args.dataset
    dataset_path = os.path.join(dataset_path, dataset_name)
    device = 'cuda' if args.with_cuda else 'cpu'

    tasks = get_tasks(args.operation, args.backward_test, device)
    repeats = 3
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "scipy_matmul": scipy_matmul,
                "matmul_backward": matmul_backward,
                "sparse_matmul_backward": sparse_matmul_backward,
                **variables
            },
            label=label,
            sub_label=sub_label,
            description=f"{sparsity}",
            env=device,
        )
        for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
        for label, device, sub_label, stmt in tasks
        for variables in
        load_dlmc_dataset(dataset_path, args.operation, args.hidden_size, sparsity, device, args.backward_test)
    ]
    measurements = []

    for i, timer in enumerate(timers * repeats):
        m = timer.blocked_autorange(min_run_time=args.timer_min_run_time)
        m.metadata = {
            "device": 'cuda' if m.task_spec.env.find("cuda") >= 0 else 'cpu'
        }
        measurements.append(m)
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

    comparison = benchmark_utils.Compare(measurements)

    print("== Results " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()


@ -0,0 +1,27 @@
#!/bin/bash
DATASET_ROOT_DIR=$HOME/datasets/
# wget https://storage.googleapis.com/sgk-sc2020/dlmc.tar.gz -P $DATASET_ROOT_DIR
# tar -xvf $DATASET_ROOT_DIR/dlmc.tar.gz
echo "!! SPARSE SPMS TIME BENCHMARK!! "
# cpu
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse --backward_test
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense --backward_test
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@vector
# cuda
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse --with_cuda
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse --with_cuda --backward_test
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense --with_cuda
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense --with_cuda --backward_test
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@vector --with_cuda


@ -0,0 +1,199 @@
import torch
from pathlib import Path
from scipy import sparse
import math


def to_coo_scipy(x):
    indices_1 = x._indices().numpy()
    values_1 = x._values().numpy()
    return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])),
                             shape=x.shape)


def sparse_grad_output(a, b):
    c = torch.sparse.mm(a, b)
    if c.is_sparse:
        c2 = torch.rand_like(c.to_dense())
        return c2.sparse_mask(c.coalesce())
    else:
        return torch.rand_like(c)


def read_matrix_params(path):
    with open(path, 'r') as file:
        line = file.readline()
        nrows, ncols, nnz = map(lambda el: int(el), line.split(', '))
        return (nrows, ncols), nnz


def csr_to_coo(indices, indptr, shape):
    n_rows, n_cols = shape
    cols = indices
    rows = [0] * len(cols)
    for i in range(n_rows):
        for j in range(indptr[i], indptr[i + 1]):
            rows[j] = i
    return torch.tensor([rows, cols], dtype=torch.long)


def load_sparse_matrix(path, device):
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
        index_pointers = map(lambda el: int(el), file.readline().split())
        indices = map(lambda el: int(el), file.readline().split())

    index_pointers = list(index_pointers)
    indices = list(indices)
    data = torch.randn(nnz, dtype=torch.double)
    shape = (nrows, ncols)
    return torch.sparse_coo_tensor(csr_to_coo(indices, index_pointers, shape), data, shape, device=device)


def gen_vector(path, device):
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
        index_pointers = map(lambda el: int(el), file.readline().split())
        indices = map(lambda el: int(el), file.readline().split())
        return torch.randn(nrows, dtype=torch.double, device=device)


def gen_matrix(path, device):
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
        index_pointers = map(lambda el: int(el), file.readline().split())
        indices = map(lambda el: int(el), file.readline().split())
        return torch.randn(nrows, ncols, dtype=torch.double, device=device)


def load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit=math.inf):
    """load_spmv_dataset loads a DLMC dataset for a sparse matrix-vector multiplication (SPMV) performance test.
    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        hidden_size
            This value allows tensors of varying sizes.
        sparsity:
            This value allows tensors of varying sparsities.
        device:
            Whether to place the Tensor on a GPU or CPU.
        n_limit:
            This value allows a dataset with some limit size.
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(f.as_posix())
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_vector(fy, device)
        yield (x, y)


def load_spmm_dataset(dataset_path, hidden_size, sparsity, spmm_type, device, n_limit=math.inf):
    """load_spmm_dataset loads a DLMC dataset for a sparse matrix-matrix multiplication (SPMM) performance test.
    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        hidden_size
            This value allows tensors of varying sizes.
        sparsity:
            This value allows tensors of varying sparsities.
        spmm_type:
            This value allows tensors for `sparse@sparse` or `sparse@dense` operations.
        device:
            Whether to place the Tensor on a GPU or CPU.
        n_limit:
            This value allows a dataset with some limit size.
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(f.as_posix())
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_matrix(fy, device) if spmm_type == 'sparse@dense' else load_sparse_matrix(fy, device)
        yield (x, y)


def load_dlmc_dataset(dataset_path, operation, hidden_size, sparsity, device, requires_grad, n_limit=math.inf):
    """load_dlmc_dataset loads a DLMC dataset for a matmul performance test.
    Args:
        dataset_path:
            path of the dataset from DLMC collection.
        operation:
            This value allows tensors for `sparse@sparse`|`sparse@dense`|`sparse@vector` operations.
        hidden_size
            This value allows tensors of varying sizes.
        sparsity:
            This value allows tensors of varying sparsities.
        device:
            Whether to place the Tensor on a GPU or CPU.
        requires_grad:
            Loads the dataset for backward test.
        n_limit:
            This value allows a dataset with some limit size.
    """
    if operation == 'sparse@sparse' or operation == "sparse@dense":
        collection = load_spmm_dataset(dataset_path, hidden_size, sparsity, operation, device, n_limit)
    elif operation == 'sparse@vector':
        collection = load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit)
    scipy_vars = {}
    backward_vars = {}
    for x, y in collection:
        if device == 'cpu':
            scipy_vars = {
                "sx": to_coo_scipy(x) if x.is_sparse else x.numpy(),
                "sy": to_coo_scipy(y) if y.is_sparse else y.numpy(),
            }
        if not requires_grad:
            dx = x.to_dense() if x.is_sparse else x
            dy = y.to_dense() if y.is_sparse else y
        else:
            c = sparse_grad_output(x, y)
            backward_vars = {
                "sparse_grad_output": c,
                "grad_output": c.to_dense() if c.is_sparse else c,
            }
            x.requires_grad_(True)
            y.requires_grad_(True)
            dx = x.to_dense().detach() if x.is_sparse else x.clone().detach()
            dy = y.to_dense().detach() if y.is_sparse else y.clone().detach()
            dx.requires_grad_(True)
            dy.requires_grad_(True)
        yield {
            "x": x,
            "y": y,
            "dx": dx,
            "dy": dy,
            **scipy_vars,
            **backward_vars
        }


@ -1,198 +0,0 @@
# Sparse benchmarks

# These benchmarks are for the sparse matrix functionality.
# They exist for comparing the performance of sparse matrix routines
# torch.sparse.mm(sparse, sparse)` with different backends (CPU/CUDA)
# and with other frameworks such as scipy.

import sys
from scipy import sparse
import numpy as np
from pathlib import Path
import pandas as pd
import argparse
import torch
import torch.utils.benchmark as benchmark_utils


def read_matrix_params(path):
    sys.stdin = open(path)
    nrows, ncols, nnz = map(lambda el: int(el), input().split(', '))
    return (nrows, ncols), nnz


def load_matrix(path):
    sys.stdin = open(path)
    nrows, ncols, nnz = map(lambda el: int(el), input().split(', '))
    index_pointers = map(lambda el: int(el), input().split())
    indices = map(lambda el: int(el), input().split())

    index_pointers = list(index_pointers)
    indices = list(indices)
    data = np.random.rand(nnz)
    coo = sparse.csr_matrix(
        (data, np.array(indices), np.array(index_pointers)),
        shape=(nrows, ncols)).tocoo()
    return torch.sparse_coo_tensor([coo.row, coo.col], coo.data, coo.shape)


def scipy_coo_matmul(mat1, mat2):
    result = mat1.dot(mat2).tocoo()
    return torch.sparse_coo_tensor([result.row, result.col], result.data,
                                   result.shape)


def to_coo_scipy(x):
    indices_1 = x._indices().numpy()
    values_1 = x._values().numpy()
    return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])),
                             shape=x.shape)


def torch_backward(a_dense, b_dense):
    a_dense.requires_grad = True
    b_dense.requires_grad = True
    r1 = a_dense.matmul(b_dense)
    f1 = torch.sum(r1)
    f1.backward()


def sparse_torch_backward(a, b):
    a.requires_grad = True
    b.requires_grad = True

    r2 = torch.sparse.mm(a, b)
    f2 = torch.sparse.sum(r2)
    f2.backward()


def load_dataset(dataset_path, hidden_size, sparsity, n_limit=20):
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    xs = []
    ys = []
    print(dataset_path, hidden_size, sparsity)
    index = 0
    for elem in files:
        if index == n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(elem.as_posix())
        if size[1] == hidden_size:
            xs.append(load_matrix(elem.as_posix()))
        if size[0] == hidden_size:
            ys.append(load_matrix(elem.as_posix()))
        index += 1
    print()
    return zip(xs, ys)


if __name__ == '__main__':

    path = Path()
    parser = argparse.ArgumentParser(description='Sparse Matmul Bench')

    parser.add_argument('--path', type=str, help='dataset path')
    parser.add_argument('--dataset',
                        type=str,
                        help='dataset name',
                        default='random_pruning')
    parser.add_argument('--operation',
                        type=str,
                        help='matmul or backward',
                        default='matmul')
    parser.add_argument('--output',
                        type=str,
                        help='dataframe output path',
                        default='/tmp/matmul_bench.pkl')
    args = parser.parse_args()
    print('path =', args.path)
    print('dataset =', args.dataset)
    print('operation =', args.operation)
    print('output =', args.output)

    dataset_path = args.path
    dataset_name = args.dataset
    dataset_path = f"{dataset_path}/{dataset_name}"
    df_output_path = args.output
    tasks = []
    if args.operation == 'matmul':
        tasks = [
            ("matmul", "cpu", "torch", "torch.mm(dense_x, dense_y)"),
            ("matmul", "cpu", "torch.sparse", "torch.sparse.mm(tx, ty)"),
            ("matmul", "cpu", "scipy",
             "scipy_coo_matmul(scipy_varx, scipy_vary)"),
            ("matmul", "cuda", "torch",
             "torch.mm(dense_cuda_x, dense_cuda_y)"),
            ("matmul", "cuda", "torch.sparse",
             "torch.sparse.mm(tx_cuda, ty_cuda)"),
        ]
    else:
        tasks = [
            ("backward", "cpu", "torch", "torch_backward(dense_x, dense_y)"),
            ("backward", "cpu", "torch.sparse",
             "sparse_torch_backward(tx, ty)"),
            ("backward", "cuda", "torch",
             "torch_backward(dense_cuda_x, dense_cuda_y)"),
            ("backward", "cuda", "torch.sparse",
             "sparse_torch_backward(tx_cuda, ty_cuda)"),
        ]
    serialized_results = []
    repeats = 2
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "scipy_coo_matmul": scipy_coo_matmul,
                "torch_backward": torch_backward,
                "sparse_torch_backward": sparse_torch_backward,
                "scipy_varx": to_coo_scipy(x),
                "scipy_vary": to_coo_scipy(y),
                "tx": x,
                "ty": y,
                "tx_cuda": x.cuda(),
                "ty_cuda": y.cuda(),
                "dense_cuda_x": x.to_dense().cuda(),
                "dense_cuda_y": y.to_dense().cuda(),
                "dense_x": x.to_dense(),
                "dense_y": y.to_dense(),
            },
            label=label,
            sub_label=sub_label,
            description=f"{sparsity}",
            env=device,
            # num_threads=num_threads,
        ) for hidden_size in [512]
        for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
        for label, device, sub_label, stmt in tasks
        for num_threads in [1, 4, 8, 16]
        for x, y in load_dataset(dataset_path, hidden_size, sparsity)
    ]
    measurements = []

    for i, timer in enumerate(timers * repeats):
        m = timer.blocked_autorange(min_run_time=0.05)
        serialized_results.append(pickle.dumps(m))
        m.metadata = {
            "device": 'cuda' if m.task_spec.env.find("cuda") >= 0 else 'cpu'
        }
        measurements.append(m)
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

    comparison = benchmark_utils.Compare(
        [pickle.loads(i) for i in serialized_results])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()

    table = [(m.task_spec.sub_label, m.task_spec.description,
              m.metadata["device"], m.mean) for m in measurements]
    df = pd.DataFrame(table, columns=['method', 'sparsity', 'device', 'time'])
    df.to_pickle(df_output_path)


@ -1,14 +0,0 @@
#!/bin/bash
DATASET_ROOT_DIR=$HOME/datasets/
# wget https://storage.googleapis.com/sgk-sc2020/dlmc.tar.gz -P $DATASET_ROOT_DIR
# tar -xvf $DATASET_ROOT_DIR/dlmc.tar.gz
echo "!! SPARSE SPMS TIME BENCHMARK!! "
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation matmul --output /tmp/matmul_bench.pkl
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation backward --output /tmp/backward_bench.pkl
python plot_results.py -i /tmp/matmul_bench.pkl
python plot_results.py -i /tmp/backward_bench.pkl