Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
The benchmark is failing with the following error:

```
  File "/var/lib/jenkins/workspace/benchmarks/gpt_fast/benchmark.py", line 333, in <module>
    main(output_file=args.output, only_model=args.only)
  File "/var/lib/jenkins/workspace/benchmarks/gpt_fast/benchmark.py", line 308, in main
    lst = func(device)
  File "/var/lib/jenkins/workspace/benchmarks/gpt_fast/benchmark.py", line 66, in run_mlp_layer_norm_gelu
    us_per_iter = benchmarker.benchmark(compiled_mod, (x,)) * 1000
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
    return fn(self, *args, **kwargs)
TypeError: benchmark() missing 1 required positional argument: 'fn_kwargs'
```

An example error is https://github.com/pytorch/pytorch/actions/runs/12862761823/job/35858912555. I also assign `oncall: pt2` as the owner of this job going forward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145235
Approved by: https://github.com/nmacchioni
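The root cause is that `benchmarker.benchmark` was being called with only the function and its positional arguments, while the wrapped `benchmark()` also expects a `fn_kwargs` positional argument; the file below now passes an explicit empty dict. A minimal sketch of the corrected call pattern (the module and input here are illustrative placeholders, not the benchmark's actual configuration):

```python
import torch
from torch._inductor.runtime.benchmarking import benchmarker

# Illustrative module and input; the real experiments are defined in benchmark.py below.
device = "cuda" if torch.cuda.is_available() else "cpu"
mod = torch.nn.Linear(1024, 1024, device=device, dtype=torch.bfloat16)
x = torch.randn(1024, device=device, dtype=torch.bfloat16)
compiled_mod = torch.compile(mod, dynamic=False)

# Broken: benchmarker.benchmark(compiled_mod, (x,))
#   -> TypeError: benchmark() missing 1 required positional argument: 'fn_kwargs'
# Fixed: pass fn_args and fn_kwargs explicitly, even when there are no keyword arguments.
# benchmark.py below treats the return value as milliseconds and multiplies by 1000
# to get microseconds per iteration.
us_per_iter = benchmarker.benchmark(compiled_mod, (x,), {}) * 1000
```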
336 lines · 10 KiB · Python
import argparse
import csv
import dataclasses
import json
import os

from common import all_experiments, Experiment, register_experiment
from generate import get_arch_name

import torch
import torch.nn as nn
from torch._inductor.runtime.benchmarking import benchmarker
from torch.utils.flop_counter import FlopCounterMode


WARMUP_ITER = 5

A100_40G_BF16_TFLOPS = 312


class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, dtype=dtype),
                nn.LayerNorm(hidden_dim, dtype=dtype),
                nn.Linear(hidden_dim, output_dim, dtype=dtype),
                nn.LayerNorm(output_dim, dtype=dtype),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


@register_experiment(name="mlp_layer_norm_gelu")
|
|
def run_mlp_layer_norm_gelu(device: str = "cuda"):
|
|
dtype_flops_utilization_map = {
|
|
torch.bfloat16: "0.8",
|
|
}
|
|
input_shapes = [1024, 4096, 8192, 16384]
|
|
intermediate_size = 14336
|
|
results = []
|
|
for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
|
|
flops_utilization = 0
|
|
for D in input_shapes:
|
|
mod = SimpleMLP(
|
|
input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
|
|
).to(device)
|
|
|
|
x = torch.randn(D, device=device, dtype=torch.bfloat16)
|
|
|
|
with FlopCounterMode(display=False) as mode:
|
|
mod(x)
|
|
|
|
flops = mode.get_total_flops()
|
|
|
|
compiled_mod = torch.compile(mod, dynamic=False)
|
|
|
|
for _ in range(WARMUP_ITER):
|
|
compiled_mod(x)
|
|
|
|
us_per_iter = benchmarker.benchmark(compiled_mod, (x,), {}) * 1000
|
|
flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS
|
|
|
|
flops_utilization = flops_utilization / len(input_shapes)
|
|
dtype_str = str(dtype).replace("torch.", "")
|
|
results.append(
|
|
Experiment(
|
|
"mlp_layer_norm_gelu",
|
|
"flops_utilization",
|
|
expected_flops_utilization,
|
|
f"{flops_utilization:.02f}",
|
|
dtype_str,
|
|
device,
|
|
get_arch_name(),
|
|
)
|
|
)
|
|
return results
|
|
|
|
|
|
@register_experiment(name="layer_norm")
|
|
def run_layer_norm(device: str = "cuda"):
|
|
dtype_memory_bandwidth_map = {
|
|
torch.bfloat16: "950",
|
|
}
|
|
input_shapes = [1024, 4096, 8192, 16384]
|
|
BS = 4096
|
|
results = []
|
|
for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
|
|
memory_bandwidth = 0
|
|
for D in input_shapes:
|
|
mod = nn.LayerNorm(D).to(device)
|
|
|
|
x = torch.randn(BS, D, device=device, dtype=dtype)
|
|
|
|
compiled_mod = torch.compile(mod, dynamic=False)
|
|
|
|
for _ in range(WARMUP_ITER):
|
|
compiled_mod(x)
|
|
|
|
us_per_iter = benchmarker.benchmark(compiled_mod, (x,), {}) * 1000
|
|
memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9
|
|
|
|
memory_bandwidth = memory_bandwidth / len(input_shapes)
|
|
dtype_str = str(dtype).replace("torch.", "")
|
|
results.append(
|
|
Experiment(
|
|
"layer_norm",
|
|
"memory_bandwidth(GB/s)",
|
|
expected_memory_bandwidth,
|
|
f"{memory_bandwidth:.02f}",
|
|
dtype_str,
|
|
device,
|
|
get_arch_name(),
|
|
)
|
|
)
|
|
return results
|
|
|
|
|
|
@register_experiment(name="gather_gemv")
|
|
@torch._inductor.config.patch(coordinate_descent_tuning=True)
|
|
def run_gather_gemv(device: str = "cuda"):
|
|
E = 8
|
|
dtype_memory_bandwidth_map = {
|
|
torch.int8: "990",
|
|
torch.bfloat16: "1060",
|
|
}
|
|
input_shapes = [1024, 4096, 8192, 16384]
|
|
results = []
|
|
for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
|
|
memory_bandwidth = 0
|
|
for D in input_shapes:
|
|
|
|
def gather_gemv(W, score_idxs, x):
|
|
return W[score_idxs].to(x.dtype) @ x
|
|
|
|
W = torch.randn(E, D, D, device=device).to(dtype=dtype)
|
|
x = torch.randn(D, device=device, dtype=torch.bfloat16)
|
|
score_idxs = torch.tensor([3, 5], device=device)
|
|
|
|
compiled_fn = torch.compile(gather_gemv, dynamic=False)
|
|
|
|
for _ in range(WARMUP_ITER):
|
|
compiled_fn(W, score_idxs, x)
|
|
|
|
us_per_iter = (
|
|
benchmarker.benchmark(
|
|
compiled_fn,
|
|
(
|
|
W,
|
|
score_idxs,
|
|
x,
|
|
),
|
|
{},
|
|
)
|
|
* 1000
|
|
)
|
|
memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9
|
|
|
|
memory_bandwidth = memory_bandwidth / len(input_shapes)
|
|
dtype_str = str(dtype).replace("torch.", "")
|
|
results.append(
|
|
Experiment(
|
|
"gather_gemv",
|
|
"memory_bandwidth(GB/s)",
|
|
expected_memory_bandwidth,
|
|
f"{memory_bandwidth:.02f}",
|
|
dtype_str,
|
|
device,
|
|
get_arch_name(),
|
|
)
|
|
)
|
|
return results
|
|
|
|
|
|
@register_experiment(name="gemv")
|
|
@torch._inductor.config.patch(coordinate_descent_tuning=True)
|
|
def run_gemv(device: str = "cuda"):
|
|
dtype_memory_bandwidth_map = {
|
|
torch.int8: "870",
|
|
torch.bfloat16: "990",
|
|
}
|
|
input_shapes = [1024, 4096, 8192, 16384]
|
|
results = []
|
|
for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
|
|
memory_bandwidth = 0
|
|
for D in input_shapes:
|
|
|
|
def gemv(W, x):
|
|
return W.to(x.dtype) @ x
|
|
|
|
W = torch.randn(D, D, device=device).to(dtype=dtype)
|
|
x = torch.randn(D, device=device, dtype=torch.bfloat16)
|
|
|
|
compiled_fn = torch.compile(gemv, dynamic=False)
|
|
|
|
for _ in range(WARMUP_ITER):
|
|
compiled_fn(W, x)
|
|
|
|
us_per_iter = (
|
|
benchmarker.benchmark(
|
|
compiled_fn,
|
|
(
|
|
W,
|
|
x,
|
|
),
|
|
{},
|
|
)
|
|
* 1000
|
|
)
|
|
memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9
|
|
|
|
memory_bandwidth = memory_bandwidth / len(input_shapes)
|
|
dtype_str = str(dtype).replace("torch.", "")
|
|
results.append(
|
|
Experiment(
|
|
"gemv",
|
|
"memory_bandwidth(GB/s)",
|
|
expected_memory_bandwidth,
|
|
f"{memory_bandwidth:.02f}",
|
|
dtype_str,
|
|
device,
|
|
get_arch_name(),
|
|
)
|
|
)
|
|
return results
|
|
|
|
|
|
def output_csv(output_file, headers, row):
    if os.path.exists(output_file):
        with open(output_file) as fd:
            lines = list(csv.reader(fd)) or [[]]
            if headers and len(headers) > len(lines[0]):
                # if prior results failed the header might not be filled in yet
                lines[0] = headers
            else:
                headers = lines[0]
    else:
        lines = [headers]

    if output_file != DEFAULT_OUTPUT_FILE:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(output_file, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))


def output_json(output_file, headers, row):
    """
    Write the result into JSON format, so that it can be uploaded to the benchmark database
    to be displayed on OSS dashboard. The JSON format is defined at
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
    """
    mapping_headers = {headers[i]: v for i, v in enumerate(row)}
    record = {
        "benchmark": {
            "name": "PyTorch gpt-fast benchmark",
            "mode": "inference",
            "dtype": mapping_headers["dtype"],
            "extra_info": {
                "device": mapping_headers["device"],
                "arch": mapping_headers["arch"],
            },
        },
        "model": {
            "name": mapping_headers["name"],
            "type": "OSS model" if mapping_headers["is_model"] else "micro-benchmark",
            "origins": ["pytorch"],
        },
        "metric": {
            "name": mapping_headers["metric"],
            "benchmark_values": [mapping_headers["actual"]],
            "target_value": mapping_headers["target"],
        },
    }

    with open(f"{os.path.splitext(output_file)[0]}.json", "a") as f:
        print(json.dumps(record), file=f)


DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"


def main(output_file=DEFAULT_OUTPUT_FILE, only_model=None):
    results = []

    if not only_model:
        experiments = all_experiments.values()
    else:
        if only_model not in all_experiments:
            print(
                f"Unknown model: {only_model}, all available models: {all_experiments.keys()}"
            )
        # only run the specified model
        experiments = [all_experiments[only_model]]
    for func in experiments:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        except AssertionError:
            # This happens when torch is compiled with CUDA turned off completely
            device = "cpu"

        torch.compiler.cudagraph_mark_step_begin()
        lst = func(device)
        for x in lst:
            results.append(dataclasses.astuple(x))

    headers = [field.name for field in dataclasses.fields(Experiment)]

    for row in results:
        output_csv(output_file, headers, row)
        # Also write the output in JSON format so that it can be ingested into the OSS benchmark database
        output_json(output_file, headers, row)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--output",
        default=DEFAULT_OUTPUT_FILE,
        help="Set the output CSV file to save the benchmark results",
    )
    parser.add_argument(
        "--only",
        help="Specify a model or micro-benchmark name to run exclusively",
    )
    args = parser.parse_args()

    main(output_file=args.output, only_model=args.only)