mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
matmul performance benchmarks (#51647)
Summary: Minor PR following up the previous PR about sparse benchmarking utils https://github.com/pytorch/pytorch/pull/48397 Fixes https://github.com/pytorch/pytorch/issues/44634: Performance benchmarks for matrix-matrix and matrix-vector ops (dense-sparse, sparse-sparse, and compare to dense-dense) I ran all benchmarks on an 2xRTX8000 machine with AMD 2970WX 24-cores for `DLMC/magnitude_pruning` dataset with different sparsity levels. --- <details><summary> forward tests (expand for details). </summary> - `sparse@sparse` ``` [------------------------------- cpu:matmul-forward -------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ------------------------------------------------------------------------- torch:dense@dense | 108.1 | 100.5 | 101.3 | 108.4 | 98.4 | 187.4 torch:sparse@sparse | 659.1 | 368.8 | 156.5 | 53.3 | 26.8 | 14.9 scipy:sparse@sparse | 565.1 | 233.9 | 130.2 | 23.1 | 21.6 | 15.2 Times are in milliseconds (ms). [----------------------------------- cuda:matmul-forward -----------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ---------------------------------------------------------------------------------- torch:dense@dense | 2243.5 | 4392.5 | 4419.8 | 2272.3 | 4433.9 | 8920.1 torch:sparse@sparse | 21369.2 | 11877.6 | 7339.2 | 1787.2 | 1335.1 | 845.7 Times are in microseconds (us). ``` - `sparse@dense` ``` [------------------------------- cpu:matmul-forward -------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ------------------------------------------------------------------------- torch:dense@dense | 105.8 | 103.8 | 103.0 | 104.4 | 104.4 | 197.0 torch:sparse@dense | 119.9 | 102.4 | 84.0 | 19.7 | 16.8 | 11.6 scipy:sparse@dense | 906.5 | 799.6 | 697.8 | 182.2 | 165.5 | 135.4 Times are in milliseconds (ms). [------------------------- cuda:matmul-forward --------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: --------------------------------------------------------------- torch:dense@dense | 2.2 | 4.4 | 4.4 | 2.3 | 4.5 | 2.3 torch:sparse@dense | 5.7 | 6.6 | 4.5 | 1.4 | 1.4 | 1.3 Times are in milliseconds (ms). ``` - `sparse@vector` ``` [----------------------------------- cpu:matmul-forward ----------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: -------------------------------------------------------------------------------- torch:dense@vector | 510.6 | 505.8 | 759.6 | 782.1 | 682.4 | 764.6 torch:sparse@vector | 10122.8 | 6241.1 | 7935.6 | 2076.3 | 1049.5 | 826.3 scipy:sparse@vector | 1756.7 | 1033.9 | 678.2 | 343.5 | 168.5 | 65.4 Times are in microseconds (us). [-------------------------------- cuda:matmul-forward --------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ---------------------------------------------------------------------------- torch:dense@vector | 36.1 | 21.5 | 21.6 | 21.5 | 21.6 | 21.5 torch:sparse@vector | 1099.2 | 1289.4 | 775.7 | 327.1 | 285.4 | 274.0 Times are in microseconds (us). ``` </details> --- <details><summary> backward tests (expand for details). </summary> - `sparse@sparse` ``` [--------------------------------- cpu:matmul-backward ---------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ------------------------------------------------------------------------------ torch:dense@dense | 246.1 | 315.0 | 306.9 | 168.6 | 290.6 | 146.9 torch:sparse@sparse | 6417.5 | 4393.7 | 3012.7 | 1029.4 | 908.0 | 650.7 Times are in microseconds (us). [----------------------------- cuda:matmul-backward -----------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ----------------------------------------------------------------------- torch:dense@dense | 6.7 | 13.3 | 13.3 | 6.9 | 13.5 | 6.9 torch:sparse@sparse | 143.7 | 143.4 | 119.6 | 29.5 | 29.1 | 10.9 Times are in microseconds (us). ``` - `sparse@dense` ``` [------------------------------ cpu:matmul-backward -------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ------------------------------------------------------------------------- torch:dense@dense | 185.9 | 304.8 | 305.8 | 169.9 | 308.7 | 168.4 torch:sparse@dense | 407.9 | 345.8 | 274.6 | 114.2 | 163.6 | 230.5 Times are in milliseconds (ms). [--------------------------- cuda:matmul-backward --------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ------------------------------------------------------------------ torch:dense@dense | 6.7 | 13.3 | 13.3 | 6.9 | 13.4 | 6.9 torch:sparse@dense | 16.7 | 19.0 | 15.1 | 6.3 | 8.2 | 12.7 Times are in milliseconds (ms). ``` </details> Kindly review this PR. cc mruberry, ngimel Pull Request resolved: https://github.com/pytorch/pytorch/pull/51647 Reviewed By: albanD Differential Revision: D27007809 Pulled By: mruberry fbshipit-source-id: 8c1922cb3280027ca5e3eef31bfa20500c548cfd
This commit is contained in:
committed by
Facebook GitHub Bot
parent
142c6b0e55
commit
39f50f468d
15
benchmarks/sparse/dlmc/README.md
Normal file
15
benchmarks/sparse/dlmc/README.md
Normal file
@ -0,0 +1,15 @@
|
||||
# Sparse benchmarks
|
||||
|
||||
These sets of benchmarks are for the sparse matrix functionality using a popular real dataset collection called the Deep Learning Matrix Collection (DLMC), which were used in recent studies [1, 2].
|
||||
|
||||
Performance benchmarks scripts for matrix-matrix and matrix-vector ops (dense-sparse, sparse-sparse, and compare to dense-dense) are implemented here.
|
||||
|
||||
- `matmul_bench.py` with `--operation sparse@sparse|sparse@dense` is for Sparse matrix-matrix multiplication (SPMM) performance test. It can run in forward and backward mode with `--backward_test`, on CPU or CUDA with `--with_cuda`, using different datasets from the dataset collection DLMC. For more details see `test.sh` file.
|
||||
|
||||
- `matmul_bench.py` with `--operation sparse@vector` is for Sparse matrix-vector multiplication (SPMV) performance test.
|
||||
|
||||
References:
|
||||
|
||||
1. Trevor Gale, Matei Zaharia, Cliff Young, Erich Elsen. Sparse GPU Kernels for Deep Learning. Proceedings of the International Conference for High Performance Computing, 2020. https://github.com/google-research/google-research/tree/master/sgk
|
||||
|
||||
2. Trevor Gale, Erich Elsen, Sara Hooker. The State of Sparsity in Deep Neural Networks. https://github.com/google-research/google-research/tree/master/state_of_sparsity
|
3
benchmarks/sparse/dlmc/__init__.py
Normal file
3
benchmarks/sparse/dlmc/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
126
benchmarks/sparse/dlmc/matmul_bench.py
Normal file
126
benchmarks/sparse/dlmc/matmul_bench.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Sparse benchmarks
|
||||
|
||||
# This benchmark is for sparse matmul performance test.
|
||||
# They exist for comparing the performance of sparse matrix routines
|
||||
# `sparse @ vector`, `sparse @ sparse` and `sparse @ dense` with different backends (CPU/CUDA)
|
||||
# and with other frameworks such as scipy.
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark_utils
|
||||
from .utils import load_dlmc_dataset
|
||||
from scipy.sparse import isspmatrix
|
||||
import os
|
||||
|
||||
|
||||
def scipy_matmul(mat1, mat2):
|
||||
if isspmatrix(mat1) and isspmatrix(mat2):
|
||||
return mat1.dot(mat2).tocoo()
|
||||
return mat1.dot(mat2)
|
||||
|
||||
def matmul_backward(a_dense, b_dense, grad_output):
|
||||
r1 = a_dense.matmul(b_dense)
|
||||
r1.backward(grad_output)
|
||||
|
||||
|
||||
def sparse_matmul_backward(a, b, grad_output):
|
||||
c = torch.sparse.mm(a, b)
|
||||
c.backward(grad_output)
|
||||
|
||||
|
||||
OPS_MAP = {
|
||||
"sparse@sparse": "torch.sparse.mm",
|
||||
"sparse@dense": "torch.matmul",
|
||||
"sparse@vector": "torch.matmul",
|
||||
}
|
||||
|
||||
|
||||
# also get the arguments as input from the user using `argparse`
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='matmul benchmark')
|
||||
parser.add_argument('--path', type=str, help='DLMC dataset path')
|
||||
parser.add_argument('--dataset', type=str, default='magnitude_pruning')
|
||||
parser.add_argument('--hidden_size', default=2048, type=int)
|
||||
parser.add_argument('--backward_test', action="store_true")
|
||||
parser.add_argument('--operation', type=str, help="|".join(OPS_MAP.keys()), default=next(iter(OPS_MAP)))
|
||||
parser.add_argument('--with_cuda', action='store_true')
|
||||
parser.add_argument('--timer_min_run_time', default=1, type=float)
|
||||
return parser
|
||||
|
||||
|
||||
def get_tasks(op, backward_test, device):
|
||||
def filter_ops(operation):
|
||||
if backward_test:
|
||||
test_name = device + ":matmul-backward"
|
||||
return [
|
||||
(test_name, device, "torch:" + operation.replace("sparse", "dense"),
|
||||
"matmul_backward(dx, dy, grad_output)"),
|
||||
(test_name, device, "torch:" + operation, "sparse_matmul_backward(x, y, sparse_grad_output)")
|
||||
]
|
||||
else:
|
||||
test_name = device + ":matmul-forward"
|
||||
return list(filter(None, [
|
||||
(test_name, device, "torch:" + operation.replace("sparse", "dense"),
|
||||
"{}(dx, dy)".format(OPS_MAP[operation])),
|
||||
(test_name, device, "torch:" + operation, "{}(x, y)".format(OPS_MAP[operation])),
|
||||
(test_name, device, "scipy:" + operation, "scipy_matmul(sx, sy)") if device == "cpu" else None
|
||||
]))
|
||||
|
||||
all_operations = {
|
||||
"sparse@sparse": filter_ops("sparse@sparse"),
|
||||
"sparse@dense": filter_ops("sparse@dense"),
|
||||
"sparse@vector": filter_ops("sparse@vector"),
|
||||
}
|
||||
return all_operations[op]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.with_cuda and not torch.cuda.is_available():
|
||||
raise RuntimeError("No CUDA available")
|
||||
|
||||
dataset_path = args.path
|
||||
dataset_name = args.dataset
|
||||
dataset_path = os.path.join(dataset_path, dataset_name)
|
||||
device = 'cuda' if args.with_cuda else 'cpu'
|
||||
|
||||
tasks = get_tasks(args.operation, args.backward_test, device)
|
||||
repeats = 3
|
||||
timers = [
|
||||
benchmark_utils.Timer(
|
||||
stmt=stmt,
|
||||
globals={
|
||||
"scipy_matmul": scipy_matmul,
|
||||
"matmul_backward": matmul_backward,
|
||||
"sparse_matmul_backward": sparse_matmul_backward,
|
||||
**variables
|
||||
},
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=f"{sparsity}",
|
||||
env=device,
|
||||
)
|
||||
for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
|
||||
for label, device, sub_label, stmt in tasks
|
||||
for variables in
|
||||
load_dlmc_dataset(dataset_path, args.operation, args.hidden_size, sparsity, device, args.backward_test)
|
||||
]
|
||||
measurements = []
|
||||
|
||||
for i, timer in enumerate(timers * repeats):
|
||||
m = timer.blocked_autorange(min_run_time=args.timer_min_run_time)
|
||||
m.metadata = {
|
||||
"device": 'cuda' if m.task_spec.env.find("cuda") >= 0 else 'cpu'
|
||||
}
|
||||
measurements.append(m)
|
||||
print(f"\r{i + 1} / {len(timers) * repeats}", end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
|
||||
comparison = benchmark_utils.Compare(measurements)
|
||||
|
||||
print("== Results " + "=" * 80 + "\n" + "/" * 95 + "\n")
|
||||
comparison.print()
|
27
benchmarks/sparse/dlmc/test.sh
Normal file
27
benchmarks/sparse/dlmc/test.sh
Normal file
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
|
||||
DATASET_ROOT_DIR=$HOME/datasets/
|
||||
|
||||
# wget https://storage.googleapis.com/sgk-sc2020/dlmc.tar.gz -P $DATASET_ROOT_DIR
|
||||
# tar -xvf $DATASET_ROOT_DIR/dlmc.tar.gz
|
||||
|
||||
echo "!! SPARSE SPMS TIME BENCHMARK!! "
|
||||
|
||||
# cpu
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse --backward_test
|
||||
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense --backward_test
|
||||
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@vector
|
||||
|
||||
|
||||
# cuda
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse --with_cuda
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@sparse --with_cuda--backward_test
|
||||
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense --with_cuda
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@dense --with_cuda --backward_test
|
||||
|
||||
python -m dlmc.matmul_bench --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset magnitude_pruning --operation sparse@vector --with_cuda
|
199
benchmarks/sparse/dlmc/utils.py
Normal file
199
benchmarks/sparse/dlmc/utils.py
Normal file
@ -0,0 +1,199 @@
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from scipy import sparse
|
||||
import math
|
||||
|
||||
|
||||
def to_coo_scipy(x):
|
||||
indices_1 = x._indices().numpy()
|
||||
values_1 = x._values().numpy()
|
||||
return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])),
|
||||
shape=x.shape)
|
||||
|
||||
|
||||
def sparse_grad_output(a, b):
|
||||
c = torch.sparse.mm(a, b)
|
||||
if c.is_sparse:
|
||||
c2 = torch.rand_like(c.to_dense())
|
||||
return c2.sparse_mask(c.coalesce())
|
||||
else:
|
||||
return torch.rand_like(c)
|
||||
|
||||
|
||||
def read_matrix_params(path):
|
||||
with open(path, 'r') as file:
|
||||
line = file.readline()
|
||||
nrows, ncols, nnz = map(lambda el: int(el), line.split(', '))
|
||||
return (nrows, ncols), nnz
|
||||
|
||||
|
||||
def csr_to_coo(indices, indptr, shape):
|
||||
n_rows, n_cols = shape
|
||||
cols = indices
|
||||
rows = [0] * len(cols)
|
||||
for i in range(n_rows):
|
||||
for j in range(indptr[i], indptr[i + 1]):
|
||||
rows[j] = i
|
||||
return torch.tensor([rows, cols], dtype=torch.long)
|
||||
|
||||
|
||||
def load_sparse_matrix(path, device):
|
||||
with open(path, 'r') as file:
|
||||
nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
|
||||
index_pointers = map(lambda el: int(el), file.readline().split())
|
||||
indices = map(lambda el: int(el), file.readline().split())
|
||||
|
||||
index_pointers = list(index_pointers)
|
||||
indices = list(indices)
|
||||
data = torch.randn(nnz, dtype=torch.double)
|
||||
shape = (nrows, ncols)
|
||||
return torch.sparse_coo_tensor(csr_to_coo(indices, index_pointers, shape), data, shape, device=device)
|
||||
|
||||
|
||||
def gen_vector(path, device):
|
||||
with open(path, 'r') as file:
|
||||
nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
|
||||
index_pointers = map(lambda el: int(el), file.readline().split())
|
||||
indices = map(lambda el: int(el), file.readline().split())
|
||||
return torch.randn(nrows, dtype=torch.double, device=device)
|
||||
|
||||
|
||||
def gen_matrix(path, device):
|
||||
with open(path, 'r') as file:
|
||||
nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
|
||||
index_pointers = map(lambda el: int(el), file.readline().split())
|
||||
indices = map(lambda el: int(el), file.readline().split())
|
||||
return torch.randn(nrows, ncols, dtype=torch.double, device=device)
|
||||
|
||||
|
||||
def load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit=math.inf):
|
||||
"""load_spmv_dataset loads a DLMC dataset for a sparse matrix-vector multiplication (SPMV) performance test.
|
||||
Args:
|
||||
dataset_path:
|
||||
path of the dataset from DLMC collection.
|
||||
hidden_size
|
||||
This value allows tensors of varying sizes.
|
||||
sparsity:
|
||||
This value allows tensors of varying sparsities.
|
||||
device:
|
||||
Whether to place the Tensor on a GPU or CPU.
|
||||
n_limit:
|
||||
This value allows a dataset with some limit size.
|
||||
"""
|
||||
current_folder_path = f"{dataset_path}/{sparsity}"
|
||||
path = Path(current_folder_path)
|
||||
files = path.glob('**/*.smtx')
|
||||
print(dataset_path, hidden_size, sparsity)
|
||||
index = 0
|
||||
x_files, y_files = [], []
|
||||
for f in files:
|
||||
if index >= n_limit:
|
||||
break
|
||||
print('.', end='')
|
||||
size, nnz = read_matrix_params(f.as_posix())
|
||||
if size[1] == hidden_size:
|
||||
x_files.append(f.as_posix())
|
||||
if size[0] == hidden_size:
|
||||
y_files.append(f.as_posix())
|
||||
index += 1
|
||||
print()
|
||||
|
||||
for fx, fy in zip(x_files, y_files):
|
||||
x = load_sparse_matrix(fx, device)
|
||||
y = gen_vector(fy, device)
|
||||
yield (x, y)
|
||||
|
||||
|
||||
def load_spmm_dataset(dataset_path, hidden_size, sparsity, spmm_type, device, n_limit=math.inf):
|
||||
"""load_spmm_dataset loads a DLMC dataset for a sparse matrix-matrix multiplication (SPMM) performance test.
|
||||
Args:
|
||||
dataset_path:
|
||||
path of the dataset from DLMC collection.
|
||||
hidden_size
|
||||
This value allows tensors of varying sizes.
|
||||
sparsity:
|
||||
This value allows tensors of varying sparsities.
|
||||
spmm_type:
|
||||
This value allows tensors for `sparse@sparse` or `sparse@dense` operations.
|
||||
device:
|
||||
Whether to place the Tensor on a GPU or CPU.
|
||||
n_limit:
|
||||
This value allows a dataset with some limit size.
|
||||
"""
|
||||
current_folder_path = f"{dataset_path}/{sparsity}"
|
||||
path = Path(current_folder_path)
|
||||
files = path.glob('**/*.smtx')
|
||||
print(dataset_path, hidden_size, sparsity)
|
||||
index = 0
|
||||
x_files, y_files = [], []
|
||||
for f in files:
|
||||
if index >= n_limit:
|
||||
break
|
||||
print('.', end='')
|
||||
size, nnz = read_matrix_params(f.as_posix())
|
||||
if size[1] == hidden_size:
|
||||
x_files.append(f.as_posix())
|
||||
if size[0] == hidden_size:
|
||||
y_files.append(f.as_posix())
|
||||
index += 1
|
||||
print()
|
||||
|
||||
for fx, fy in zip(x_files, y_files):
|
||||
x = load_sparse_matrix(fx, device)
|
||||
y = gen_matrix(fy, device) if spmm_type == 'sparse@dense' else load_sparse_matrix(fy, device)
|
||||
yield (x, y)
|
||||
|
||||
|
||||
def load_dlmc_dataset(dataset_path, operation, hidden_size, sparsity, device, requires_grad, n_limit=math.inf):
|
||||
"""load_dlmc_dataset loads a DLMC dataset for a matmul performance test.
|
||||
Args:
|
||||
dataset_path:
|
||||
path of the dataset from DLMC collection.
|
||||
operation:
|
||||
This value allows tensors for `sparse@sparse`|`sparse@dense`|`sparse@vector` operations.
|
||||
hidden_size
|
||||
This value allows tensors of varying sizes.
|
||||
sparsity:
|
||||
This value allows tensors of varying sparsities.
|
||||
device:
|
||||
Whether to place the Tensor on a GPU or CPU.
|
||||
requires_grad:
|
||||
Loads the dataset for backward test.
|
||||
n_limit:
|
||||
This value allows a dataset with some limit size.
|
||||
"""
|
||||
if operation == 'sparse@sparse' or operation == "sparse@dense":
|
||||
collection = load_spmm_dataset(dataset_path, hidden_size, sparsity, operation, device, n_limit)
|
||||
elif operation == 'sparse@vector':
|
||||
collection = load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit)
|
||||
scipy_vars = {}
|
||||
backward_vars = {}
|
||||
for x, y in collection:
|
||||
if device == 'cpu':
|
||||
scipy_vars = {
|
||||
"sx": to_coo_scipy(x) if x.is_sparse else x.numpy(),
|
||||
"sy": to_coo_scipy(y) if y.is_sparse else y.numpy(),
|
||||
}
|
||||
if not requires_grad:
|
||||
dx = x.to_dense() if x.is_sparse else x
|
||||
dy = y.to_dense() if y.is_sparse else y
|
||||
else:
|
||||
c = sparse_grad_output(x, y)
|
||||
backward_vars = {
|
||||
"sparse_grad_output": c,
|
||||
"grad_output": c.to_dense() if c.is_sparse else c,
|
||||
}
|
||||
x.requires_grad_(True)
|
||||
y.requires_grad_(True)
|
||||
dx = x.to_dense().detach() if x.is_sparse else x.clone().detach()
|
||||
dy = y.to_dense().detach() if y.is_sparse else y.clone().detach()
|
||||
dx.requires_grad_(True)
|
||||
dy.requires_grad_(True)
|
||||
yield {
|
||||
"x": x,
|
||||
"y": y,
|
||||
"dx": dx,
|
||||
"dy": dy,
|
||||
**scipy_vars,
|
||||
**backward_vars
|
||||
}
|
@ -1,198 +0,0 @@
|
||||
# Sparse benchmarks
|
||||
|
||||
# These benchmarks are for the sparse matrix functionality.
|
||||
# They exist for comparing the performance of sparse matrix routines
|
||||
# torch.sparse.mm(sparse, sparse)` with different backends (CPU/CUDA)
|
||||
# and with other frameworks such as scipy.
|
||||
|
||||
import sys
|
||||
from scipy import sparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import argparse
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark_utils
|
||||
|
||||
def read_matrix_params(path):
|
||||
sys.stdin = open(path)
|
||||
nrows, ncols, nnz = map(lambda el: int(el), input().split(', '))
|
||||
return (nrows, ncols), nnz
|
||||
|
||||
|
||||
def load_matrix(path):
|
||||
sys.stdin = open(path)
|
||||
nrows, ncols, nnz = map(lambda el: int(el), input().split(', '))
|
||||
index_pointers = map(lambda el: int(el), input().split())
|
||||
indices = map(lambda el: int(el), input().split())
|
||||
|
||||
index_pointers = list(index_pointers)
|
||||
indices = list(indices)
|
||||
data = np.random.rand(nnz)
|
||||
coo = sparse.csr_matrix(
|
||||
(data, np.array(indices), np.array(index_pointers)),
|
||||
shape=(nrows, ncols)).tocoo()
|
||||
return torch.sparse_coo_tensor([coo.row, coo.col], coo.data, coo.shape)
|
||||
|
||||
|
||||
def scipy_coo_matmul(mat1, mat2):
|
||||
result = mat1.dot(mat2).tocoo()
|
||||
return torch.sparse_coo_tensor([result.row, result.col], result.data,
|
||||
result.shape)
|
||||
|
||||
|
||||
def to_coo_scipy(x):
|
||||
indices_1 = x._indices().numpy()
|
||||
values_1 = x._values().numpy()
|
||||
return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])),
|
||||
shape=x.shape)
|
||||
|
||||
|
||||
def torch_backward(a_dense, b_dense):
|
||||
a_dense.requires_grad = True
|
||||
b_dense.requires_grad = True
|
||||
r1 = a_dense.matmul(b_dense)
|
||||
f1 = torch.sum(r1)
|
||||
f1.backward()
|
||||
|
||||
|
||||
def sparse_torch_backward(a, b):
|
||||
a.requires_grad = True
|
||||
b.requires_grad = True
|
||||
|
||||
r2 = torch.sparse.mm(a, b)
|
||||
f2 = torch.sparse.sum(r2)
|
||||
f2.backward()
|
||||
|
||||
|
||||
def load_dataset(dataset_path, hidden_size, sparsity, n_limit=20):
|
||||
current_folder_path = f"{dataset_path}/{sparsity}"
|
||||
path = Path(current_folder_path)
|
||||
files = path.glob('**/*.smtx')
|
||||
xs = []
|
||||
ys = []
|
||||
print(dataset_path, hidden_size, sparsity)
|
||||
index = 0
|
||||
for elem in files:
|
||||
if index == n_limit:
|
||||
break
|
||||
print('.', end='')
|
||||
size, nnz = read_matrix_params(elem.as_posix())
|
||||
if size[1] == hidden_size:
|
||||
xs.append(load_matrix(elem.as_posix()))
|
||||
if size[0] == hidden_size:
|
||||
ys.append(load_matrix(elem.as_posix()))
|
||||
index += 1
|
||||
print()
|
||||
return zip(xs, ys)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
path = Path()
|
||||
parser = argparse.ArgumentParser(description='Sparse Matmul Bench')
|
||||
|
||||
parser.add_argument('--path', type=str, help='dataset path')
|
||||
parser.add_argument('--dataset',
|
||||
type=str,
|
||||
help='dataset name',
|
||||
default='random_pruning')
|
||||
parser.add_argument('--operation',
|
||||
type=str,
|
||||
help='matmul or backward',
|
||||
default='matmul')
|
||||
parser.add_argument('--output',
|
||||
type=str,
|
||||
help='dataframe output path',
|
||||
default='/tmp/matmul_bench.pkl')
|
||||
args = parser.parse_args()
|
||||
print('path =', args.path)
|
||||
print('dataset =', args.dataset)
|
||||
print('operation =', args.operation)
|
||||
print('output =', args.output)
|
||||
|
||||
dataset_path = args.path
|
||||
dataset_name = args.dataset
|
||||
dataset_path = f"{dataset_path}/{dataset_name}"
|
||||
df_output_path = args.output
|
||||
tasks = []
|
||||
if args.operation == 'matmul':
|
||||
tasks = [
|
||||
("matmul", "cpu", "torch", "torch.mm(dense_x, dense_y)"),
|
||||
("matmul", "cpu", "torch.sparse", "torch.sparse.mm(tx, ty)"),
|
||||
("matmul", "cpu", "scipy",
|
||||
"scipy_coo_matmul(scipy_varx, scipy_vary)"),
|
||||
("matmul", "cuda", "torch",
|
||||
"torch.mm(dense_cuda_x, dense_cuda_y)"),
|
||||
("matmul", "cuda", "torch.sparse",
|
||||
"torch.sparse.mm(tx_cuda, ty_cuda)"),
|
||||
]
|
||||
else:
|
||||
tasks = [
|
||||
("backward", "cpu", "torch", "torch_backward(dense_x, dense_y)"),
|
||||
("backward", "cpu", "torch.sparse",
|
||||
"sparse_torch_backward(tx, ty)"),
|
||||
("backward", "cuda", "torch",
|
||||
"torch_backward(dense_cuda_x, dense_cuda_y)"),
|
||||
("backward", "cuda", "torch.sparse",
|
||||
"sparse_torch_backward(tx_cuda, ty_cuda)"),
|
||||
]
|
||||
serialized_results = []
|
||||
repeats = 2
|
||||
timers = [
|
||||
benchmark_utils.Timer(
|
||||
stmt=stmt,
|
||||
globals={
|
||||
"scipy_coo_matmul": scipy_coo_matmul,
|
||||
"torch_backward": torch_backward,
|
||||
"sparse_torch_backward": sparse_torch_backward,
|
||||
"scipy_varx": to_coo_scipy(x),
|
||||
"scipy_vary": to_coo_scipy(y),
|
||||
"tx": x,
|
||||
"ty": y,
|
||||
"tx_cuda": x.cuda(),
|
||||
"ty_cuda": y.cuda(),
|
||||
"dense_cuda_x": x.to_dense().cuda(),
|
||||
"dense_cuda_y": y.to_dense().cuda(),
|
||||
"dense_x": x.to_dense(),
|
||||
"dense_y": y.to_dense(),
|
||||
},
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=f"{sparsity}",
|
||||
env=device,
|
||||
# num_threads=num_threads,
|
||||
) for hidden_size in [512]
|
||||
for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
|
||||
for label, device, sub_label, stmt in tasks
|
||||
for num_threads in [1, 4, 8, 16]
|
||||
for x, y in load_dataset(dataset_path, hidden_size, sparsity)
|
||||
]
|
||||
measurements = []
|
||||
|
||||
for i, timer in enumerate(timers * repeats):
|
||||
m = timer.blocked_autorange(min_run_time=0.05)
|
||||
serialized_results.append(pickle.dumps(m))
|
||||
m.metadata = {
|
||||
"device": 'cuda' if m.task_spec.env.find("cuda") >= 0 else 'cpu'
|
||||
}
|
||||
measurements.append(m)
|
||||
print(f"\r{i + 1} / {len(timers) * repeats}", end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
|
||||
comparison = benchmark_utils.Compare(
|
||||
[pickle.loads(i) for i in serialized_results])
|
||||
|
||||
print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
|
||||
comparison.print()
|
||||
|
||||
print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
|
||||
comparison.trim_significant_figures()
|
||||
comparison.colorize()
|
||||
comparison.print()
|
||||
|
||||
table = [(m.task_spec.sub_label, m.task_spec.description,
|
||||
m.metadata["device"], m.mean) for m in measurements]
|
||||
df = pd.DataFrame(table, columns=['method', 'sparsity', 'device', 'time'])
|
||||
df.to_pickle(df_output_path)
|
@ -1,14 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
DATASET_ROOT_DIR=$HOME/datasets/
|
||||
|
||||
# wget https://storage.googleapis.com/sgk-sc2020/dlmc.tar.gz -P $DATASET_ROOT_DIR
|
||||
# tar -xvf $DATASET_ROOT_DIR/dlmc.tar.gz
|
||||
|
||||
echo "!! SPARSE SPMS TIME BENCHMARK!! "
|
||||
|
||||
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation matmul --output /tmp/matmul_bench.pkl
|
||||
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation backward --output /tmp/backward_bench.pkl
|
||||
|
||||
python plot_results.py -i /tmp/matmul_bench.pkl
|
||||
python plot_results.py -i /tmp/backward_bench.pkl
|
Reference in New Issue
Block a user