Files
pytorch/benchmarks/sparse/benchmark_semi_structured_sparsity.py
PyTorch MergeBot 5318bf8baf Revert "[sparse] add extra options to _cslt_spare_mm (#137427)"
This reverts commit f1451163ecd2bd014cb80a40c41c9999fbc94af8.

Reverted https://github.com/pytorch/pytorch/pull/137427 on behalf of https://github.com/huydhn due to This looks like the test is still failing, plz do a rebase ([comment](https://github.com/pytorch/pytorch/pull/137427#issuecomment-2499918590))
2024-11-26 08:01:24 +00:00

254 lines
6.5 KiB
Python

import argparse
import random
import pandas as pd
from tqdm import tqdm
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
# Make printed tensors easier to inspect while debugging the benchmark:
# 2 digits of floating-point precision, 16 edge items per dimension when a
# tensor is summarized, 480-character lines, and no scientific notation.
# `threshold` and `profile` are passed as None explicitly.
torch.set_printoptions(
precision=2,
threshold=None,
edgeitems=16,
linewidth=480,
profile=None,
sci_mode=False,
)
# helper model definition for pruner
class Model(nn.Module):
    """Single ``nn.Linear`` layer used as the dense/sparse benchmark target.

    Args:
        m: Number of output features of the linear layer.
        k: Number of input features of the linear layer.
        dtype: Optional dtype for the layer's parameters. ``None`` (the
            default) keeps PyTorch's default parameter dtype.
    """

    def __init__(self, m, k, dtype=None):
        super().__init__()
        # nn.Linear takes (in_features, out_features), so the arguments are
        # reversed here: the resulting weight has shape (m, k).
        # `dtype` used to be accepted and silently dropped; forward it so an
        # explicit request is honoured.
        self.linear = nn.Linear(k, m, dtype=dtype)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)
def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    Build an (r, c) mask out of ``r * c // 2`` two-element pairs.

    Each pair is ``choice`` when a truthy ``choice`` is supplied; otherwise it
    is drawn independently (via the ``random`` module) from ``[0, 1]`` and
    ``[1, 0]``, which yields a 1:2 sparse matrix.
    Note that a 1:2 sparse matrix is also 2:4 and 4:8 sparse.

    The pairs are laid out in row-major order, then reshaped to (r, c) and
    returned as a contiguous tensor of the requested ``dtype`` on ``device``.
    """
    patterns = [[0, 1], [1, 0]]
    num_pairs = r * c // 2

    pairs = []
    for _ in range(num_pairs):
        # A falsy `choice` (None, empty list) means "pick a pattern at random".
        pairs.append(choice or random.choice(patterns))

    flat_mask = torch.tensor(pairs, dtype=dtype, device=device)
    mask = flat_mask.reshape(r, c)
    return mask.contiguous()
def test_linear(m, k, n, dtype, contiguous, backend):
    """Benchmark a dense vs. semi-structured-sparse ``nn.Linear`` forward pass.

    Builds a ``Model(m, k)`` on CUDA, times its forward pass on an (n, k)
    input, swaps the layer weight for a semi-structured sparse (m, k) weight,
    and times the forward pass again.

    Args:
        m, k, n: Problem sizes; the weight is (m, k) and the input is (n, k).
        dtype: Torch dtype used for the mask, weight, input and model.
        contiguous: Not used by this function; kept so both benchmark
            functions share one call signature.
        backend: ``"cutlass"`` forces the CUTLASS kernels; any other value
            leaves ``_FORCE_CUTLASS`` False.

    Returns:
        A dict (one results row) with the sizes, dtype, backend, median
        latencies in milliseconds, the dense/sparse speedup, whether the two
        outputs agree within rtol/atol of 1e-3, and whether the sparse output
        is contiguous.
    """
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    # NOTE: the input is all zeros, so both forward passes reduce to the
    # layer's bias; `correct` below therefore does not exercise the weights.
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    def time_forward():
        # Only the names referenced by `stmt` are handed to the timer.
        timer = benchmark.Timer(
            stmt="model(input_tensor)",
            globals={"model": model, "input_tensor": input_tensor},
        )
        return timer.blocked_autorange()

    dense_timing = time_forward()
    dense_output = model(input_tensor)
    print(dense_output.shape)

    # sparsify weights
    compressed_weight = to_sparse_semi_structured(sparse_weight)
    model.linear.weight = nn.Parameter(compressed_weight)

    sparse_output = model(input_tensor)
    print(sparse_output.shape)
    sparse_timing = time_forward()

    outputs_match = torch.allclose(
        dense_output, sparse_output, rtol=1e-3, atol=1e-3
    )

    row = {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
    }
    row["sparse_latency (ms)"] = sparse_timing.median * 1000
    row["dense_latency (ms)"] = dense_timing.median * 1000
    row["speedup (d/s)"] = dense_timing.median / sparse_timing.median
    row["correct"] = outputs_match
    row["contiguous"] = sparse_output.is_contiguous()
    return row
def test_tensor(m, k, n, dtype, contiguous, backend):
    """Benchmark dense ``torch.mm`` against semi-structured-sparse ``torch.mm``.

    ``A`` is an (m, k) 1:2 mask used directly as the left operand, ``B`` is a
    (k, n) zero matrix, and ``sA`` is the compressed form of ``A``.

    Args:
        m, k, n: Problem sizes for the (m, k) x (k, n) product.
        dtype: Torch dtype of both operands. For ``torch.int8`` no dense
            baseline is run; the "dense" numbers are a second sparse run.
        contiguous: Not used by this function; kept so both benchmark
            functions share one call signature.
        backend: ``"cutlass"`` forces the CUTLASS kernels; any other value
            leaves ``_FORCE_CUTLASS`` False.

    Returns:
        A dict (one results row) with the sizes, dtype, backend, median
        latencies in milliseconds, the dense/sparse speedup, whether the two
        outputs agree within rtol/atol of 1e-3, and whether the sparse output
        is contiguous.
    """
    # Apply the requested backend, exactly as test_linear does. Without this
    # the results row was labelled with `backend` while the kernels actually
    # used were whatever a previous call (or the library default) selected.
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    # NOTE: B is all zeros, so both products are zero matrices; `correct`
    # below does not exercise the values of A.
    B = torch.zeros(k, n).to(dtype).cuda()
    sA = to_sparse_semi_structured(A)

    def time_mm(stmt):
        # Hand the timer only the names the timed statements reference.
        timer = benchmark.Timer(
            stmt=stmt,
            globals={"torch": torch, "A": A, "B": B, "sA": sA},
        )
        return timer.blocked_autorange()

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)
        dense_measurement = time_mm("torch.mm(A, B)")
    else:
        # No dense int8 baseline is run: the sparse product stands in for it,
        # so the reported speedup for int8 compares sparse against sparse.
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)
        dense_measurement = time_mm("torch.mm(sA, B)")

    sparse_output = torch.mm(sA, B)
    sparse_measurement = time_mm("torch.mm(sA, B)")

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
# Command-line driver: parse options, pick the benchmark function and the
# list of (m, k, n) shapes for the selected mode, run every shape, and
# print (optionally save) the collected results table.
if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        choices=[
            "nvidia-bert",
            "nvidia-fixed-k",
            "nvidia-fixed-mn",
        ],
        # Without a mode no shape list is selected and the script used to
        # crash with a NameError on `results`; let argparse reject it instead.
        required=True,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=list(dtype_lookup),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    # -e2e benchmarks a full nn.Linear forward; otherwise a bare torch.mm.
    eval_fn = test_linear if args.e2e else test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    # Every mode reduces to a list of (m, k, n) problem sizes.
    if args.mode == "nvidia-bert":
        shapes = [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ]
    elif args.mode == "nvidia-fixed-k":
        # k fixed at 10240; m == n sweeps 3072..20480 in steps of 1024.
        shapes = [(mn, 10240, mn) for mn in range(3072, 20480 + 1, 1024)]
    elif args.mode == "nvidia-fixed-mn":
        # m == n fixed at 10240; k sweeps 2560..20480 in steps of 1280.
        shapes = [(10240, k, 10240) for k in range(2560, 20480 + 1, 1280)]
    else:
        # Unreachable while `choices` and the branches above stay in sync.
        parser.error(f"unsupported mode: {args.mode}")

    results = [
        eval_fn(m, k, n, dtype, args.contiguous, args.backend)
        for (m, k, n) in tqdm(shapes)
    ]

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)