Files
pytorch/benchmarks/sparse/benchmark_semi_structured_sparsity.py
Xuehai Pan 26f4f10ac8 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
2024-05-27 14:49:57 +00:00

254 lines
6.5 KiB
Python

import argparse
import random
import pandas as pd
from tqdm import tqdm
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
torch.set_printoptions(
precision=2,
threshold=None,
edgeitems=16,
linewidth=480,
profile=None,
sci_mode=False,
)
# helper model definition for pruner
class Model(nn.Module):
def __init__(self, m, k, dtype=None):
super().__init__()
# transposed so reversed
self.linear = nn.Linear(k, m)
def forward(self, x):
return self.linear(x)
def rand_sparse_semi_structured_mask(
r, c, dtype=torch.float16, device="cuda", choice=None
):
"""
This function returns a 1:2 sparse matrix of size (r, c).
Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
"""
choices = [[0, 1], [1, 0]]
mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
return (
torch.tensor(mask_entries, dtype=dtype, device=device)
.reshape(r, c)
.contiguous()
)
def test_linear(m, k, n, dtype, contiguous, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
input_tensor = torch.zeros(n, k).to(dtype).cuda()
model = Model(m, k).to(dtype).cuda().eval()
dense_measurement = benchmark.Timer(
stmt="model(input_tensor)",
globals=locals(),
).blocked_autorange()
dense_output = model(input_tensor)
print(dense_output.shape)
# sparsify weights
model.linear.weight = nn.Parameter(
to_sparse_semi_structured(
sparse_weight,
)
)
sparse_output = model(input_tensor)
print(sparse_output.shape)
sparse_measurement = benchmark.Timer(
stmt="model(input_tensor)",
globals=locals(),
).blocked_autorange()
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
return {
"test_function": "linear",
"m": m,
"k": k,
"n": n,
"dtype": str(dtype),
"backend": backend,
"sparse_latency (ms)": sparse_measurement.median * 1000,
"dense_latency (ms)": dense_measurement.median * 1000,
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
"correct": correct,
"contiguous": sparse_output.is_contiguous(),
}
def test_tensor(m, k, n, dtype, contiguous, backend):
A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
B = torch.zeros(k, n).to(dtype).cuda()
bias = torch.rand(n).to(dtype).cuda()
sA = to_sparse_semi_structured(A)
# torch.mm calculation
if dtype is not torch.int8:
dense_output = torch.mm(A, B)
dense_measurement = benchmark.Timer(
stmt="torch.mm(A, B)",
globals=locals(),
).blocked_autorange()
else:
print("int8 baseline not supported")
dense_output = torch.mm(sA, B)
dense_measurement = benchmark.Timer(
stmt="torch.mm(sA, B)",
globals=locals(),
).blocked_autorange()
sparse_output = torch.mm(sA, B)
sparse_measurement = benchmark.Timer(
stmt="torch.mm(sA, B)",
globals=locals(),
).blocked_autorange()
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
return {
"test_function": "tensor",
"m": m,
"k": k,
"n": n,
"dtype": str(dtype),
"backend": backend,
"sparse_latency (ms)": sparse_measurement.median * 1000,
"dense_latency (ms)": dense_measurement.median * 1000,
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
"correct": correct,
"contiguous": sparse_output.is_contiguous(),
}
if __name__ == "__main__":
dtype_lookup = {
"int8": torch.int8,
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
parser.add_argument(
"--mode",
type=str,
choices=[
"nvidia-bert",
"nvidia-fixed-k",
"nvidia-fixed-mn",
],
)
parser.add_argument(
"--dtype",
type=str,
choices=dtype_lookup.keys(),
default="fp16",
)
parser.add_argument(
"--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
)
parser.add_argument("-contiguous", action="store_true")
parser.add_argument("-e2e", action="store_true")
parser.add_argument("-save", action="store_true")
args = parser.parse_args()
if args.e2e:
eval_fn = test_linear
else:
eval_fn = test_tensor
print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
dtype = dtype_lookup[args.dtype]
if args.mode == "nvidia-bert":
bert_shapes = [
(3072, 1024, 16384),
(4096, 1024, 16384),
(1024, 1024, 16384),
(1024, 4096, 16384),
]
results = (
eval_fn(m, k, n, dtype, args.contiguous, args.backend)
for (m, k, n) in tqdm(bert_shapes)
)
elif args.mode == "nvidia-fixed-k":
mn_vals = [
3072,
4096,
5120,
6144,
7168,
8192,
9216,
10240,
11264,
12288,
13312,
14336,
15360,
16384,
17408,
18432,
19456,
20480,
]
results = (
eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
for mn in tqdm(mn_vals)
)
elif args.mode == "nvidia-fixed-mn":
k_vals = [
2560,
3840,
5120,
6400,
7680,
8960,
10240,
11520,
12800,
14080,
15360,
16640,
17920,
19200,
20480,
]
results = (
eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
for k in tqdm(k_vals)
)
df = pd.DataFrame.from_records(results)
if args.save:
save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
df.to_csv(save_file)
print(f"Finished benchmark: {args.mode} saved results to {save_file}")
print(df)