import math
from pathlib import Path

from scipy import sparse

import torch


def to_coo_scipy(x):
    # Convert a torch sparse COO tensor into an equivalent scipy.sparse.coo_matrix.
    indices_1 = x._indices().numpy()
    values_1 = x._values().numpy()
    return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])), shape=x.shape)


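# A minimal sketch (not part of the benchmarks) showing the intended round trip:
# the small tensor below is an illustrative assumption, and the dense forms of
# the torch and scipy representations should match.
def _to_coo_scipy_example():
    i = torch.tensor([[0, 1], [2, 0]])
    v = torch.tensor([1.0, 2.0])
    t = torch.sparse_coo_tensor(i, v, (2, 3)).coalesce()
    s = to_coo_scipy(t)
    assert (s.toarray() == t.to_dense().numpy()).all()

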
def sparse_grad_output(a, b):
    # Build a random grad_output matching the layout of a @ b: when the product
    # is sparse, restrict the random gradient to the product's sparsity pattern.
    c = torch.sparse.mm(a, b)
    if c.is_sparse:
        c2 = torch.rand_like(c.to_dense())
        return c2.sparse_mask(c.coalesce())
    else:
        return torch.rand_like(c)


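# Illustrative sketch (assumed small operands, not benchmark data): a
# sparse @ sparse product yields a sparse grad_output with the product's
# pattern, while sparse @ dense yields a dense one.
def _sparse_grad_output_example():
    a = torch.eye(3).to_sparse()
    b = torch.eye(3).to_sparse()
    g = sparse_grad_output(a, b)  # sparse @ sparse -> sparse gradient
    assert g.is_sparse and g._nnz() == 3
    g2 = sparse_grad_output(a, torch.randn(3, 3))  # sparse @ dense -> dense
    assert not g2.is_sparse

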
def read_matrix_params(path):
    # The first line of a DLMC .smtx file is a "nrows, ncols, nnz" header.
    with open(path) as file:
        line = file.readline()
        nrows, ncols, nnz = (int(el) for el in line.split(", "))
        return (nrows, ncols), nnz


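# A small self-contained sketch of the header format assumed above: a header
# line "4, 8, 5" parses as a 4x8 matrix with 5 non-zeros. The temporary file
# is purely illustrative.
def _read_matrix_params_example():
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".smtx", delete=False) as f:
        f.write("4, 8, 5\n")
    assert read_matrix_params(f.name) == ((4, 8), 5)

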
def csr_to_coo(indices, indptr, shape):
    # Expand CSR row pointers into explicit COO row indices: row i owns the
    # entries in the half-open range indptr[i]:indptr[i + 1].
    n_rows, n_cols = shape
    cols = indices
    rows = [0] * len(cols)
    for i in range(n_rows):
        for j in range(indptr[i], indptr[i + 1]):
            rows[j] = i
    return torch.tensor([rows, cols], dtype=torch.long)


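# A worked example (illustrative only): the CSR form of
# [[1, 0, 2], [0, 0, 0], [0, 3, 0]] has indptr [0, 2, 2, 3] and column
# indices [0, 2, 1], which expands to COO row indices [0, 0, 2].
def _csr_to_coo_example():
    coo = csr_to_coo(indices=[0, 2, 1], indptr=[0, 2, 2, 3], shape=(3, 3))
    assert coo.tolist() == [[0, 0, 2], [0, 2, 1]]

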
def load_sparse_matrix(path, device):
    # Read the CSR structure of a DLMC .smtx file and return a sparse COO
    # tensor with random values: only the sparsity pattern comes from the file.
    with open(path) as file:
        nrows, ncols, nnz = (int(el) for el in file.readline().split(", "))
        index_pointers = [int(el) for el in file.readline().split()]
        indices = [int(el) for el in file.readline().split()]

    data = torch.randn(nnz, dtype=torch.double)
    shape = (nrows, ncols)
    return torch.sparse_coo_tensor(
        csr_to_coo(indices, index_pointers, shape), data, shape, device=device
    )


def gen_vector(path, device):
    # Only the row count from the header is used; the dense vector operand
    # itself is random.
    with open(path) as file:
        nrows, ncols, nnz = (int(el) for el in file.readline().split(", "))
        return torch.randn(nrows, dtype=torch.double, device=device)


def gen_matrix(path, device):
    # Generate a random dense matrix with the shape recorded in the header.
    with open(path) as file:
        nrows, ncols, nnz = (int(el) for el in file.readline().split(", "))
        return torch.randn(nrows, ncols, dtype=torch.double, device=device)


def load_spmv_dataset(dataset_path, hidden_size, sparsity, device, n_limit=math.inf):
    """load_spmv_dataset loads a DLMC dataset for a sparse matrix-vector multiplication (SPMV) performance test.

    Args:
        dataset_path:
            Path of the dataset from the DLMC collection.
        hidden_size:
            Only matrices with a dimension equal to this value are selected.
        sparsity:
            Name of the sparsity subfolder to load matrices from.
        device:
            Device on which to place the tensors ("cpu" or "cuda").
        n_limit:
            Maximum number of matrix files to scan.
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob("**/*.smtx")
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print(".", end="")
        size, nnz = read_matrix_params(f.as_posix())
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1  # noqa: SIM113
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = gen_vector(fy, device)
        yield (x, y)


def load_spmm_dataset(
    dataset_path, hidden_size, sparsity, spmm_type, device, n_limit=math.inf
):
    """load_spmm_dataset loads a DLMC dataset for a sparse matrix-matrix multiplication (SPMM) performance test.

    Args:
        dataset_path:
            Path of the dataset from the DLMC collection.
        hidden_size:
            Only matrices with a dimension equal to this value are selected.
        sparsity:
            Name of the sparsity subfolder to load matrices from.
        spmm_type:
            Either `sparse@sparse` or `sparse@dense`, selecting the kind of second operand.
        device:
            Device on which to place the tensors ("cpu" or "cuda").
        n_limit:
            Maximum number of matrix files to scan.
    """
    current_folder_path = f"{dataset_path}/{sparsity}"
    path = Path(current_folder_path)
    files = path.glob("**/*.smtx")
    print(dataset_path, hidden_size, sparsity)
    index = 0
    x_files, y_files = [], []
    for f in files:
        if index >= n_limit:
            break
        print(".", end="")
        size, nnz = read_matrix_params(f.as_posix())
        if size[1] == hidden_size:
            x_files.append(f.as_posix())
        if size[0] == hidden_size:
            y_files.append(f.as_posix())
        index += 1  # noqa: SIM113
    print()

    for fx, fy in zip(x_files, y_files):
        x = load_sparse_matrix(fx, device)
        y = (
            gen_matrix(fy, device)
            if spmm_type == "sparse@dense"
            else load_sparse_matrix(fy, device)
        )
        yield (x, y)


def load_dlmc_dataset(
    dataset_path,
    operation,
    hidden_size,
    sparsity,
    device,
    requires_grad,
    n_limit=math.inf,
):
    """load_dlmc_dataset loads a DLMC dataset for a matmul performance test.

    Args:
        dataset_path:
            Path of the dataset from the DLMC collection.
        operation:
            One of `sparse@sparse`, `sparse@dense`, or `sparse@vector`, selecting the operands.
        hidden_size:
            Only matrices with a dimension equal to this value are selected.
        sparsity:
            Name of the sparsity subfolder to load matrices from.
        device:
            Device on which to place the tensors ("cpu" or "cuda").
        requires_grad:
            If True, also prepares gradients so the dataset can drive a backward test.
        n_limit:
            Maximum number of matrix files to scan.
    """
    if operation == "sparse@sparse" or operation == "sparse@dense":
        collection = load_spmm_dataset(
            dataset_path, hidden_size, sparsity, operation, device, n_limit
        )
    elif operation == "sparse@vector":
        collection = load_spmv_dataset(
            dataset_path, hidden_size, sparsity, device, n_limit
        )
    else:
        raise ValueError(f"unsupported operation: {operation!r}")
    scipy_vars = {}
    backward_vars = {}
    for x, y in collection:
        if device == "cpu":
            # scipy baselines are only available on CPU.
            scipy_vars = {
                "sx": to_coo_scipy(x) if x.is_sparse else x.numpy(),
                "sy": to_coo_scipy(y) if y.is_sparse else y.numpy(),
            }
        if not requires_grad:
            dx = x.to_dense() if x.is_sparse else x
            dy = y.to_dense() if y.is_sparse else y
        else:
            c = sparse_grad_output(x, y)
            backward_vars = {
                "sparse_grad_output": c,
                "grad_output": c.to_dense() if c.is_sparse else c,
            }
            x.requires_grad_(True)
            y.requires_grad_(True)
            dx = x.to_dense().detach() if x.is_sparse else x.clone().detach()
            dy = y.to_dense().detach() if y.is_sparse else y.clone().detach()
            dx.requires_grad_(True)
            dy.requires_grad_(True)
        yield {"x": x, "y": y, "dx": dx, "dy": dy, **scipy_vars, **backward_vars}
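

# A hedged usage sketch, not part of the original benchmarks: the dataset root
# and sparsity folder below are assumed placeholders for a locally downloaded
# DLMC subset. Each yielded dict holds the operands ("x", "y"), their dense
# counterparts ("dx", "dy"), and, on CPU, scipy baselines.
if __name__ == "__main__":
    samples = load_dlmc_dataset(
        dataset_path="dlmc/rn50",  # hypothetical path to a DLMC subset
        operation="sparse@vector",
        hidden_size=2048,
        sparsity="0.9",  # hypothetical sparsity subfolder name
        device="cpu",
        requires_grad=False,
        n_limit=5,
    )
    for sample in samples:
        print(sample["x"].shape, sample["y"].shape)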