Summary: Minor PR following up the previous PR about sparse benchmarking utils (https://github.com/pytorch/pytorch/pull/48397).

Fixes https://github.com/pytorch/pytorch/issues/44634: performance benchmarks for matrix-matrix and matrix-vector ops (dense-sparse, sparse-sparse, and compared to dense-dense).

I ran all benchmarks on a 2xRTX8000 machine with a 24-core AMD 2970WX, on the `DLMC/magnitude_pruning` dataset at different sparsity levels.

---

<details><summary> forward tests (expand for details). </summary>

- `sparse@sparse`
```
[------------------------------- cpu:matmul-forward -------------------------------]
                      |   0.5  |   0.7  |   0.8  |   0.9  |  0.95  |  0.98
1 threads: -------------------------------------------------------------------------
  torch:dense@dense   |  108.1 |  100.5 |  101.3 |  108.4 |   98.4 |  187.4
  torch:sparse@sparse |  659.1 |  368.8 |  156.5 |   53.3 |   26.8 |   14.9
  scipy:sparse@sparse |  565.1 |  233.9 |  130.2 |   23.1 |   21.6 |   15.2

Times are in milliseconds (ms).

[----------------------------------- cuda:matmul-forward -----------------------------------]
                      |    0.5   |    0.7   |    0.8  |    0.9  |   0.95  |   0.98
1 threads: ----------------------------------------------------------------------------------
  torch:dense@dense   |   2243.5 |   4392.5 |  4419.8 |  2272.3 |  4433.9 |  8920.1
  torch:sparse@sparse |  21369.2 |  11877.6 |  7339.2 |  1787.2 |  1335.1 |   845.7

Times are in microseconds (us).
```

- `sparse@dense`
```
[------------------------------- cpu:matmul-forward -------------------------------]
                     |   0.5  |   0.7  |   0.8  |   0.9  |  0.95  |  0.98
1 threads: -------------------------------------------------------------------------
  torch:dense@dense  |  105.8 |  103.8 |  103.0 |  104.4 |  104.4 |  197.0
  torch:sparse@dense |  119.9 |  102.4 |   84.0 |   19.7 |   16.8 |   11.6
  scipy:sparse@dense |  906.5 |  799.6 |  697.8 |  182.2 |  165.5 |  135.4

Times are in milliseconds (ms).

[------------------------- cuda:matmul-forward --------------------------]
                     |  0.5  |  0.7  |  0.8  |  0.9  |  0.95  |  0.98
1 threads: ---------------------------------------------------------------
  torch:dense@dense  |  2.2  |  4.4  |  4.4  |  2.3  |   4.5  |   2.3
  torch:sparse@dense |  5.7  |  6.6  |  4.5  |  1.4  |   1.4  |   1.3

Times are in milliseconds (ms).
```

- `sparse@vector`
```
[----------------------------------- cpu:matmul-forward ----------------------------------]
                      |    0.5   |    0.7  |    0.8  |    0.9  |   0.95  |  0.98
1 threads: --------------------------------------------------------------------------------
  torch:dense@vector  |    510.6 |   505.8 |   759.6 |   782.1 |   682.4 |  764.6
  torch:sparse@vector |  10122.8 |  6241.1 |  7935.6 |  2076.3 |  1049.5 |  826.3
  scipy:sparse@vector |   1756.7 |  1033.9 |   678.2 |   343.5 |   168.5 |   65.4

Times are in microseconds (us).

[-------------------------------- cuda:matmul-forward --------------------------------]
                      |    0.5  |    0.7  |   0.8  |   0.9  |  0.95  |  0.98
1 threads: ----------------------------------------------------------------------------
  torch:dense@vector  |    36.1 |    21.5 |   21.6 |   21.5 |   21.6 |   21.5
  torch:sparse@vector |  1099.2 |  1289.4 |  775.7 |  327.1 |  285.4 |  274.0

Times are in microseconds (us).
```
</details>

---

<details><summary> backward tests (expand for details). </summary>

- `sparse@sparse`
```
[--------------------------------- cpu:matmul-backward ---------------------------------]
                      |   0.5   |   0.7   |   0.8   |   0.9   |  0.95  |  0.98
1 threads: ------------------------------------------------------------------------------
  torch:dense@dense   |   246.1 |   315.0 |   306.9 |   168.6 |  290.6 |  146.9
  torch:sparse@sparse |  6417.5 |  4393.7 |  3012.7 |  1029.4 |  908.0 |  650.7

Times are in microseconds (us).

[----------------------------- cuda:matmul-backward -----------------------------]
                      |   0.5  |   0.7  |   0.8  |  0.9  |  0.95  |  0.98
1 threads: -----------------------------------------------------------------------
  torch:dense@dense   |    6.7 |   13.3 |   13.3 |   6.9 |  13.5 |   6.9
  torch:sparse@sparse |  143.7 |  143.4 |  119.6 |  29.5 |  29.1 |  10.9

Times are in microseconds (us).
```

- `sparse@dense`
```
[------------------------------ cpu:matmul-backward -------------------------------]
                     |   0.5  |   0.7  |   0.8  |   0.9  |  0.95  |  0.98
1 threads: -------------------------------------------------------------------------
  torch:dense@dense  |  185.9 |  304.8 |  305.8 |  169.9 |  308.7 |  168.4
  torch:sparse@dense |  407.9 |  345.8 |  274.6 |  114.2 |  163.6 |  230.5

Times are in milliseconds (ms).

[--------------------------- cuda:matmul-backward --------------------------]
                     |  0.5  |  0.7  |  0.8  |  0.9  |  0.95  |  0.98
1 threads: ------------------------------------------------------------------
  torch:dense@dense  |   6.7 |  13.3 |  13.3 |  6.9 |  13.4 |   6.9
  torch:sparse@dense |  16.7 |  19.0 |  15.1 |  6.3 |   8.2 |  12.7

Times are in milliseconds (ms).
```
</details>

Kindly review this PR.

cc mruberry, ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51647

Reviewed By: albanD

Differential Revision: D27007809

Pulled By: mruberry

fbshipit-source-id: 8c1922cb3280027ca5e3eef31bfa20500c548cfd
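For reference, here is a minimal sketch of how a single one of these measurements can be taken with `torch.utils.benchmark.Timer` and the loaders in this file. The dataset path and `hidden_size` below are hypothetical; the path must point to a local copy of the DLMC collection.

```python
import torch.utils.benchmark as benchmark

# Fetch one (sparse, dense) operand pair at 90% sparsity.
sample = next(load_dlmc_dataset('dlmc/magnitude_pruning', 'sparse@dense',
                                hidden_size=512, sparsity='0.9',
                                device='cpu', requires_grad=False))
timer = benchmark.Timer(
    stmt='torch.sparse.mm(x, y)',
    globals={'x': sample['x'], 'y': sample['y']},
)
print(timer.blocked_autorange())
```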
Python · 200 lines · 7.0 KiB
import math
from pathlib import Path

import torch
from scipy import sparse


def to_coo_scipy(x):
    """Convert a torch sparse COO tensor to a scipy.sparse.coo_matrix."""
    indices = x._indices().numpy()
    values = x._values().numpy()
    return sparse.coo_matrix((values, (indices[0], indices[1])),
                             shape=x.shape)


def sparse_grad_output(a, b):
    """Generate a random grad_output with the same layout as torch.sparse.mm(a, b)."""
    c = torch.sparse.mm(a, b)
    if c.is_sparse:
        # Restrict the random gradient to the output's sparsity pattern.
        c2 = torch.rand_like(c.to_dense())
        return c2.sparse_mask(c.coalesce())
    else:
        return torch.rand_like(c)
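

# A minimal sketch (not part of the original file) of what sparse_grad_output
# produces: for two sparse operands, the random grad_output shares the
# sparsity pattern of torch.sparse.mm(a, b).
def _demo_sparse_grad_output():
    a = torch.eye(3, dtype=torch.double).to_sparse()
    b = torch.eye(3, dtype=torch.double).to_sparse()
    g = sparse_grad_output(a, b)
    assert g.is_sparse and g.shape == (3, 3)

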
def read_matrix_params(path):
    """Read the 'nrows, ncols, nnz' header line of a DLMC .smtx file."""
    with open(path, 'r') as file:
        line = file.readline()
        nrows, ncols, nnz = map(int, line.split(', '))
        return (nrows, ncols), nnz


def csr_to_coo(indices, indptr, shape):
    """Expand CSR row pointers into explicit COO row indices."""
    n_rows, n_cols = shape
    cols = indices
    rows = [0] * len(cols)
    for i in range(n_rows):
        # Nonzeros of row i occupy positions indptr[i]:indptr[i + 1].
        for j in range(indptr[i], indptr[i + 1]):
            rows[j] = i
    return torch.tensor([rows, cols], dtype=torch.long)
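

# A minimal sketch (not part of the original file): a 2x3 matrix with
# nonzeros at (0, 0), (0, 2) and (1, 1) in CSR form, expanded to COO.
def _demo_csr_to_coo():
    indices = [0, 2, 1]  # column index of each stored nonzero
    indptr = [0, 2, 3]   # row i spans indices[indptr[i]:indptr[i + 1]]
    coo = csr_to_coo(indices, indptr, (2, 3))
    assert coo.tolist() == [[0, 0, 1], [0, 2, 1]]

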
def load_sparse_matrix(path, device):
    """Load a DLMC .smtx file (CSR layout) as a torch sparse COO tensor with random values."""
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(int, file.readline().split(', '))
        index_pointers = [int(el) for el in file.readline().split()]
        indices = [int(el) for el in file.readline().split()]

    data = torch.randn(nnz, dtype=torch.double)
    shape = (nrows, ncols)
    return torch.sparse_coo_tensor(csr_to_coo(indices, index_pointers, shape),
                                   data, shape, device=device)
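

# A hedged sketch (not part of the original file) of the .smtx layout the
# loaders above expect: a 'nrows, ncols, nnz' header line, an indptr line,
# and a column-indices line. We write a tiny 2x3 matrix and load it back.
def _demo_load_sparse_matrix():
    import os
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.smtx', delete=False) as f:
        f.write('2, 3, 3\n'   # header: nrows, ncols, nnz
                '0 2 3\n'     # index pointers (indptr), length nrows + 1
                '0 2 1\n')    # column indices of the nonzeros
    x = load_sparse_matrix(f.name, device='cpu')
    assert x.shape == (2, 3) and x._nnz() == 3
    os.unlink(f.name)

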
def gen_vector(path, device):
    """Generate a random dense vector whose length matches the row count of the matrix at `path`."""
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(int, file.readline().split(', '))
        return torch.randn(nrows, dtype=torch.double, device=device)


def gen_matrix(path, device):
    """Generate a random dense matrix with the same shape as the matrix at `path`."""
    with open(path, 'r') as file:
        nrows, ncols, nnz = map(int, file.readline().split(', '))
        return torch.randn(nrows, ncols, dtype=torch.double, device=device)


def load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit=math.inf):
    """load_spmv_dataset loads a DLMC dataset for a sparse matrix-vector multiplication (SPMV) performance test.

    Args:
        dataset_path:
            path to the dataset from the DLMC collection.
        hidden_size:
            hidden dimension used to pair operands: a matrix whose column
            count equals hidden_size becomes the sparse operand, and one
            whose row count equals hidden_size sizes the dense vector.
        sparsity:
            sparsity level (subdirectory of the dataset) to load.
        device:
            device ('cpu' or 'cuda') on which to place the tensors.
        n_limit:
            maximum number of files to scan (defaults to no limit).
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(f.as_posix())
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_vector(fy, device)
        yield (x, y)
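

# Hedged usage sketch (not part of the original file; the dataset path is
# hypothetical and must point to a local copy of the DLMC collection).
def _demo_spmv(dataset_path='dlmc/magnitude_pruning'):
    for x, y in load_spmv_dataset(dataset_path, 512, '0.9', 'cpu', n_limit=10):
        # x: sparse (m, 512) matrix; y: dense vector of length 512.
        torch.sparse.mm(x, y.unsqueeze(-1))

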
def load_spmm_dataset(dataset_path, hidden_size, sparsity, spmm_type, device, n_limit=math.inf):
    """load_spmm_dataset loads a DLMC dataset for a sparse matrix-matrix multiplication (SPMM) performance test.

    Args:
        dataset_path:
            path to the dataset from the DLMC collection.
        hidden_size:
            hidden dimension used to pair operands: a matrix whose column
            count equals hidden_size becomes the left operand, and one whose
            row count equals hidden_size becomes the right operand.
        sparsity:
            sparsity level (subdirectory of the dataset) to load.
        spmm_type:
            either `sparse@sparse` or `sparse@dense`; selects whether the
            right operand is loaded as a sparse or a dense matrix.
        device:
            device ('cpu' or 'cuda') on which to place the tensors.
        n_limit:
            maximum number of files to scan (defaults to no limit).
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob('**/*.smtx')
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print('.', end='')
        size, nnz = read_matrix_params(f.as_posix())
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_matrix(fy, device) if spmm_type == 'sparse@dense' else load_sparse_matrix(fy, device)
        yield (x, y)
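

# Hedged usage sketch (not part of the original file; the dataset path is
# hypothetical). Pass 'sparse@sparse' instead of 'sparse@dense' to load the
# right operand as a sparse matrix.
def _demo_spmm(dataset_path='dlmc/magnitude_pruning'):
    for x, y in load_spmm_dataset(dataset_path, 512, '0.9', 'sparse@dense',
                                  'cpu', n_limit=10):
        torch.sparse.mm(x, y)

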
def load_dlmc_dataset(dataset_path, operation, hidden_size, sparsity, device, requires_grad, n_limit=math.inf):
    """load_dlmc_dataset loads a DLMC dataset for a matmul performance test.

    Args:
        dataset_path:
            path to the dataset from the DLMC collection.
        operation:
            one of `sparse@sparse`, `sparse@dense`, or `sparse@vector`.
        hidden_size:
            hidden dimension used to pair operands: a matrix whose column
            count equals hidden_size becomes the left operand, and one whose
            row count equals hidden_size sizes the right operand.
        sparsity:
            sparsity level (subdirectory of the dataset) to load.
        device:
            device ('cpu' or 'cuda') on which to place the tensors.
        requires_grad:
            if True, also prepares gradients and leaf tensors for a backward test.
        n_limit:
            maximum number of files to scan (defaults to no limit).
    """
    if operation == 'sparse@sparse' or operation == "sparse@dense":
        collection = load_spmm_dataset(dataset_path, hidden_size, sparsity, operation, device, n_limit)
    elif operation == 'sparse@vector':
        collection = load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit)
    else:
        raise ValueError(f"unknown operation: {operation}")
    scipy_vars = {}
    backward_vars = {}
    for x, y in collection:
        if device == 'cpu':
            # scipy baselines only run on CPU.
            scipy_vars = {
                "sx": to_coo_scipy(x) if x.is_sparse else x.numpy(),
                "sy": to_coo_scipy(y) if y.is_sparse else y.numpy(),
            }
        if not requires_grad:
            dx = x.to_dense() if x.is_sparse else x
            dy = y.to_dense() if y.is_sparse else y
        else:
            c = sparse_grad_output(x, y)
            backward_vars = {
                "sparse_grad_output": c,
                "grad_output": c.to_dense() if c.is_sparse else c,
            }
            x.requires_grad_(True)
            y.requires_grad_(True)
            # Dense leaf copies of x and y for the dense reference backward.
            dx = x.to_dense().detach() if x.is_sparse else x.clone().detach()
            dy = y.to_dense().detach() if y.is_sparse else y.clone().detach()
            dx.requires_grad_(True)
            dy.requires_grad_(True)
        yield {
            "x": x,
            "y": y,
            "dx": dx,
            "dy": dy,
            **scipy_vars,
            **backward_vars
        }
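

# Hedged usage sketch (not part of the original file; the dataset path is
# hypothetical): fetch one sample prepared for a backward test and run the
# sparse backward next to its dense reference.
def _demo_backward(dataset_path='dlmc/magnitude_pruning'):
    sample = next(load_dlmc_dataset(dataset_path, 'sparse@sparse', 512, '0.9',
                                    'cpu', requires_grad=True, n_limit=1))
    torch.sparse.mm(sample['x'], sample['y']).backward(sample['sparse_grad_output'])
    (sample['dx'] @ sample['dy']).backward(sample['grad_output'])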