mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in the PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126 Approved by: https://github.com/kit1980
335 lines
9.0 KiB
Python
335 lines
9.0 KiB
Python
import argparse
|
|
import operator
|
|
import time
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
|
|
import torch
|
|
import torch._C._te as te
|
|
|
|
|
|
class kernel_arena_scope:
|
|
def __enter__(self):
|
|
self.scope = te.KernelScope()
|
|
|
|
def __exit__(self, typ, val, traceback):
|
|
self.scope = None
|
|
|
|
|
|
# (nnc_name, torch_op) pairs benchmarked element-wise.  The first element is
# the NNC method name looked up on a loaded expression; the second is the
# eager-mode PyTorch op used as the baseline (called with an ``out=`` tensor).
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    # NNC's fast_log approximation is compared against the exact torch.log.
    ("fast_log", torch.log),
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac), # seems unimplemented
    # ("isnan", torch.isnan), # no out variant
]
|
|
|
|
|
|
def gen_unary_nnc_fun(nnc_name):
    """Return an NNC compute builder for the unary op named ``nnc_name``.

    The produced ``nnc_fun(A, B)`` ignores ``B`` and yields ``compute(i, j)``
    that applies the named method to the element loaded from ``A[i, j]``.
    """

    def nnc_fun(A, B):
        def compute(i, j):
            loaded = A.load([i, j])
            op = getattr(loaded, nnc_name)
            return op()

        return compute

    return nnc_fun
|
|
|
|
|
|
def gen_unary_torch_fun(torch_op):
    """Return a benchmark-closure factory for a unary eager-mode op.

    ``torch_fun(a, b, out)`` ignores ``b`` and yields a zero-argument
    callable that evaluates ``torch_op(a, out=out)`` on every invocation.
    """

    def torch_fun(a, b, out):
        return lambda: torch_op(a, out=out)

    return torch_fun
|
|
|
|
|
|
def gen_binary_nnc_fun(fn):
    """Return an NNC compute builder for a binary expression ``fn``.

    ``nnc_fun(A, B)`` yields ``compute(i, j)`` evaluating
    ``fn(A[i, j], B[i, j])``.
    """

    def nnc_fun(A, B):
        def compute(i, j):
            lhs = A.load([i, j])
            rhs = B.load([i, j])
            return fn(lhs, rhs)

        return compute

    return nnc_fun
|
|
|
|
|
|
def gen_binary_torch_fun(fn):
    """Return a benchmark-closure factory for a binary eager-mode op.

    ``pt_fun(a, b, out)`` yields a zero-argument callable evaluating
    ``fn(a, b, out=out)`` on every invocation.
    """

    def pt_fun(a, b, out):
        return lambda: fn(a, b, out=out)

    return pt_fun
|
|
|
|
|
|
def gen_int_comparison_tensors(N, M):
    """Build ``(lhs, rhs, out)`` tensors for integer comparison benchmarks.

    Both operands hold random ints in ``[0, 3)`` so equality hits often; the
    output buffer is boolean, matching the result dtype of comparison ops.
    """
    lhs = torch.randint(0, 3, (N, M))
    rhs = torch.randint(0, 3, (N, M))
    out = torch.empty((N, M), dtype=torch.bool)
    return lhs, rhs, out
|
|
|
|
|
|
def gen_float_comparison_tensors(N, M):
    """Build ``(lhs, rhs, out)`` tensors for float comparison benchmarks."""
    lhs = torch.rand(N, M)
    rhs = torch.rand(N, M)
    out = torch.empty((N, M), dtype=torch.bool)
    return lhs, rhs, out
|
|
|
|
|
|
# NNC boolean dtype, used to cast comparison results: NNC comparisons yield
# the operand dtype, while the torch baselines produce bool tensors.
te_bool = te.Dtype.Bool

# Entries are (name, nnc_lambda, torch_op[, tensor_gen]).  The optional
# fourth element builds custom (lhs, rhs, out) tensors; when absent the
# harness pads with None (see normalize_benchmarks) and defaults to floats.
binary_ops = [
    ("add", operator.add, torch.add),
    ("mul", operator.mul, torch.mul),
    ("sub", operator.sub, torch.sub),
    ("div", operator.truediv, torch.div),
    (
        "eq",
        # Cast to bool so the NNC result matches torch.eq's bool output.
        (lambda a, b: te.Cast.make(te_bool, a == b)),
        torch.eq,
        gen_int_comparison_tensors,
    ),
    (
        "gt",
        (lambda a, b: te.Cast.make(te_bool, a > b)),
        torch.gt,
        gen_float_comparison_tensors,
    ),
    (
        "lt",
        (lambda a, b: te.Cast.make(te_bool, a < b)),
        torch.lt,
        gen_float_comparison_tensors,
    ),
    (
        "gte",
        (lambda a, b: te.Cast.make(te_bool, a >= b)),
        torch.greater_equal,
        gen_float_comparison_tensors,
    ),
    (
        "lte",
        (lambda a, b: te.Cast.make(te_bool, a <= b)),
        torch.less_equal,
        gen_float_comparison_tensors,
    ),
    # ('neq', (lambda a, b: a != b), None)), # no one-op equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and), # requires more work to test
]
|
|
|
|
|
|
def nnc_relu(A, B):
|
|
def f(i, j):
|
|
return torch._C._te.ifThenElse(
|
|
A.load([i, j]) < torch._C._te.ExprHandle.float(0),
|
|
torch._C._te.ExprHandle.float(0),
|
|
A.load([i, j]),
|
|
)
|
|
|
|
return f
|
|
|
|
|
|
def pt_relu(a, b, c):
    """Eager-mode baseline: ``relu(a)``; ``b`` and ``c`` are ignored."""
    return torch.relu(a)
|
|
|
|
|
|
# Hand-written (name, nnc_builder, torch_fn) benchmarks that don't fit the
# generated unary/binary patterns; the NNC builder is used directly rather
# than produced by a gen_* factory.
custom_ops = [
    ("relu", nnc_relu, pt_relu),
    # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu)
    # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
]
|
|
|
|
|
|
def gen_custom_torch_fun(fn):
    """Like ``gen_binary_torch_fun`` but passes ``out`` positionally.

    Used for custom ops such as relu whose baseline signature is
    ``fn(a, b, out)`` rather than ``fn(a, b, out=out)``.
    """

    def pt_fun(a, b, out):
        return lambda: fn(a, b, out)

    return pt_fun
|
|
|
|
|
|
def normalize_benchmarks(ops):
    """Pad 3-tuples to 4-tuples by appending ``None`` as the shape function."""
    normalized = []
    for entry in ops:
        if len(entry) == 3:
            entry = entry + (None,)
        normalized.append(entry)
    return normalized
|
|
|
|
|
|
# Parallel benchmark tables: entry k of each list describes one op —
# names[k] is its label, nnc_fns[k] builds the NNC compute, pt_fns[k] builds
# the eager-mode closure, and shape_fns[k] (or None) builds custom tensors.
names = []
nnc_fns = []
pt_fns = []
shape_fns = []

# Unary ops use the default float tensors, so no shape function.
for nnc_name, pt_op in unary_ops:
    names.append(nnc_name)
    nnc_fns.append(gen_unary_nnc_fun(nnc_name))
    pt_fns.append(gen_unary_torch_fun(pt_op))
    shape_fns.append(None)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    names.append(name)
    nnc_fns.append(gen_binary_nnc_fun(lmbda))
    pt_fns.append(gen_binary_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

# Custom ops provide their NNC builder directly (no generator needed).
for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    names.append(name)
    nnc_fns.append(lmbda)
    pt_fns.append(gen_custom_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

# (name, nnc_fun, torch_fun, shape_fn) rows consumed by run_benchmarks.
benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))
|
|
|
|
|
|
def run_benchmarks(benchmarks, sizes):
    """Time each benchmark entry over the given sizes and return a DataFrame.

    For every ``(name, nnc_fun, torch_fun, shape_fn)`` entry and every
    ``(N, M)`` in ``sizes``, compiles the NNC kernel, times it against the
    eager-mode closure, asserts both produce the same values, and returns a
    frame with columns [name, N, M, nnc_time, torch_time, ratio], where
    ratio = torch_time / nnc_time (ratio > 1 means NNC is faster).
    """

    def get_nnc_type(dtype):
        # Map the torch dtypes produced by the tensor generators to NNC dtypes.
        if dtype == torch.float:
            return torch._C._te.Dtype.Float
        elif dtype == torch.long:
            return torch._C._te.Dtype.Long

    rows = []
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                # Scale iterations down with tensor size so each measurement
                # takes roughly constant wall time.
                iters = int(1e6 / (N + M))
                with kernel_arena_scope():
                    if shape_fn is None:
                        # Default float operands, clamped away from 0 and 1 so
                        # ops like log and rsqrt stay finite.
                        tA = torch.rand(M, N).clamp(0.01, 0.99)
                        tB = torch.rand(M, N).clamp(0.01, 0.99)
                        tX = torch.empty(M, N)
                        tR = torch.empty(M, N)
                    else:
                        tA, tB, tX = shape_fn(M, N)
                        tR = tX.clone()

                    dtype = get_nnc_type(tA.dtype)

                    dM = torch._C._te.ExprHandle.int(M)
                    dN = torch._C._te.ExprHandle.int(N)

                    A = torch._C._te.Placeholder("A", dtype, [dM, dN])
                    B = torch._C._te.Placeholder("B", dtype, [dM, dN])

                    dim_args = [
                        torch._C._te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]
                    ]

                    # Build, lower, and compile the NNC kernel.
                    compute = nnc_fun(A, B)
                    X = torch._C._te.Compute("X", dim_args, compute)
                    loopnest = torch._C._te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = torch._C._te.simplify(loopnest.root_stmt())
                    cg = torch._C._te.construct_codegen(
                        "llvm", stmt, [torch._C._te.BufferArg(x) for x in [A, B, X]]
                    )

                    # warmup
                    for _ in range(10):
                        cg.call([tA, tB, tX])
                    start = time.time()
                    for _ in range(iters):
                        cg.call([tA, tB, tX])
                    time1 = time.time() - start

                    fn = torch_fun(tA, tB, tR)
                    # warmup
                    for _ in range(10):
                        tR = fn()
                    start = time.time()
                    for _ in range(iters):
                        tR = fn()
                    time2 = time.time() - start

                    rows.append(
                        {
                            "name": name,
                            "N": N,
                            "M": M,
                            "nnc_time": time1,
                            "torch_time": time2,
                            "ratio": time2 / time1,
                        }
                    )
                    print(name, N, M)

                    print(time2 / time1, time1, time2)
                    print()

                    # Fail loudly if NNC and eager disagree; print the op name
                    # first so the failing benchmark is identifiable.
                    if not np.allclose(tX, tR):
                        print(name)
                        assert np.allclose(tX, tR)

    # Build the frame once at the end: DataFrame.append was removed in
    # pandas 2.0 and grew quadratically anyway.
    return pd.DataFrame(
        rows, columns=["name", "N", "M", "nnc_time", "torch_time", "ratio"]
    )
|
|
|
|
|
|
def dump_plot(df, sizes):
    """Render a heatmap of torch/NNC time ratios and save it to nnc.png.

    Only square cases (N == M) are plotted; each row of the heatmap is one
    op and each column one size from ``sizes``.
    """
    square = df[df["N"] == df["M"]]
    keys = square["name"].tolist()
    vals = square["ratio"].tolist()

    # Each op contributed len(sizes) consecutive rows; keep one label per op.
    keys = keys[:: len(sizes)]
    sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})

    # Diverging palette centered at 1.0: red = NNC slower, green = faster.
    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = np.array([vals]).reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title("PyTorch performance divided by NNC performance (single core)")
    plt.xlabel("Size of NxN matrix")
    plt.ylabel("Operation")
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)

    plt.savefig("nnc.png")
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
    parser.add_argument(
        "--multi-threaded",
        "--multi_threaded",
        action="store_true",
        help="Run with more than one thread",
    )
    args = parser.parse_args()
    # Pin eager mode to one thread unless the user explicitly asks for a
    # multi-threaded comparison, so both sides run single-core by default.
    if not args.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    square_shapes = [(s, s) for s in sizes]
    df = run_benchmarks(benchmarks, square_shapes)
    dump_plot(df, sizes)
|