Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[CI] enable operator benchmark on CPU (#143733)
This enables the operator benchmark on CPU to track op-level performance. It is motivated by https://github.com/pytorch/pytorch/issues/120982, with feasibility investigated in https://github.com/pytorch/pytorch/pull/127216.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143733
Approved by: https://github.com/leslie-fang-intel, https://github.com/atalman, https://github.com/huydhn, https://github.com/malfet

Co-authored-by: diwei sun <diwei.sun@intel.com>
Co-authored-by: chuanqiw <chuanqi.wang@intel.com>
Committed by: PyTorch MergeBot
Parent: 700260f166
Commit: fa5f556f88
@@ -1527,6 +1527,27 @@ test_linux_aarch64() {
    --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose
}

test_operator_benchmark() {
  TEST_REPORTS_DIR=$(pwd)/test/test-reports
  mkdir -p "$TEST_REPORTS_DIR"
  TEST_DIR=$(pwd)

  test_inductor_set_cpu_affinity

  cd benchmarks/operator_benchmark/pt_extension
  python setup.py install

  cd "${TEST_DIR}"/benchmarks/operator_benchmark
  $TASKSET python -m benchmark_all_test --device "$1" --tag-filter "$2" \
    --output-dir "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv"

  pip_install pandas
  python check_perf_csv.py \
    --actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \
    --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
}


if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
  (cd test && python -c "import torch; print(torch.__config__.show())")
  (cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@@ -1557,6 +1578,19 @@ elif [[ "$TEST_CONFIG" == distributed ]]; then
  if [[ "${SHARD_NUMBER}" == 1 ]]; then
    test_rpc
  fi
elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
  TEST_MODE="short"

  if [[ "${TEST_CONFIG}" == *cpu* ]]; then
    if [[ "${TEST_CONFIG}" == *long* ]]; then
      TEST_MODE="long"
    elif [[ "${TEST_CONFIG}" == *all* ]]; then
      TEST_MODE="all"
    fi

    test_operator_benchmark cpu ${TEST_MODE}

  fi
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
  test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
.github/pytorch-probot.yml (1 line changed)
@@ -25,6 +25,7 @@ ciflow_push_tags:
- ciflow/xpu
- ciflow/torchbench
- ciflow/autoformat
- ciflow/op-benchmark
retryable_workflows:
- pull
- trunk
.github/workflows/operator_benchmark.yml (new file, 56 lines)
@@ -0,0 +1,56 @@
name: operator_benchmark

on:
  push:
    tags:
      - ciflow/op-benchmark/*
  workflow_dispatch:
    inputs:
      test_mode:
        required: false
        type: string
        default: 'short'
        description: tag filter for operator benchmarks, options from long, short, all

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  linux-jammy-cpu-py3_9-gcc11-opbenchmark-build:
    if: github.repository_owner == 'pytorch'
    name: linux-jammy-cpu-py3.9-gcc11-opbenchmark
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.9-gcc11-build
      docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  linux-jammy-cpu-py3_9-gcc11-opbenchmark-on-demand-build:
    if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
    name: linux-jammy-cpu-py3.9-gcc11-opbenchmark
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.9-gcc11-build
      docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  linux-jammy-cpu-py3_9-gcc11-opbenchmark-test:
    name: linux-jammy-cpu-py3.9-gcc11-opbenchmark
    uses: ./.github/workflows/_linux-test.yml
    needs: linux-jammy-cpu-py3_9-gcc11-opbenchmark-build
    with:
      build-environment: linux-jammy-py3.9-gcc11-build
      docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-opbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-opbenchmark-build.outputs.test-matrix }}
    secrets: inherit
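The `workflow_dispatch` trigger above feeds its `test_mode` input into the `cpu_operator_benchmark_${{ inputs.test_mode }}` config of the on-demand build job. As a rough sketch (not part of this PR), such a run could be started through GitHub's REST API; the token variable and the target ref below are assumptions:

# Hypothetical sketch: trigger the on-demand operator_benchmark run via the
# GitHub REST API workflow_dispatch endpoint. GITHUB_TOKEN and "main" are assumptions.
import os
import requests

resp = requests.post(
    "https://api.github.com/repos/pytorch/pytorch/actions/workflows/operator_benchmark.yml/dispatches",
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",  # assumed token with workflow permission
        "Accept": "application/vnd.github+json",
    },
    json={"ref": "main", "inputs": {"test_mode": "all"}},  # options from the workflow: short, long, all
)
resp.raise_for_status()  # the API answers 204 No Content on success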
@@ -1,7 +1,9 @@
import ast
import copy
import csv
import functools
import json
import os
import timeit
from collections import namedtuple

@@ -31,6 +33,8 @@ TestConfig = namedtuple("TestConfig", "test_name input_config tag run_backward")

BENCHMARK_TESTER = []

SKIP_OP_LISTS = ["weight_norm_sparsifier_step"]


def _register_test(*test_metainfo):
    """save the metainfo needed to create a test. Currently test_metainfo
@@ -187,7 +191,9 @@ class BenchmarkRunner:
        self.use_jit = args.use_jit
        self.num_runs = args.num_runs
        self.print_per_iter = False
        self.output_dir = args.output_dir
        self.operator_range = benchmark_utils.get_operator_range(args.operator_range)
        self.disable_output = args.disable_output
        # 100 is the default warmup iterations
        if self.args.warmup_iterations == -1:
            self.args.warmup_iterations = 100
@@ -397,6 +403,9 @@ class BenchmarkRunner:
            test_flag == cmd_flag for cmd_flag in cmd_flag_list
        )

    def _check_skip(self, test_module, cmd_flag):
        return cmd_flag is None or (test_module not in cmd_flag)

    def _keep_test(self, test_case):
        # TODO: consider regex matching for test filtering.
        # Currently, this is a sub-string matching.
@@ -412,6 +421,7 @@ class BenchmarkRunner:
        return (
            self._check_keep(op_test_config.test_name, self.args.test_name)
            and self._check_keep_list(test_case.op_bench.module_name(), operators)
            and self._check_skip(test_case.op_bench.module_name(), SKIP_OP_LISTS)
            and self._check_operator_first_char(
                test_case.op_bench.module_name(), self.operator_range
            )
@@ -446,8 +456,36 @@ class BenchmarkRunner:

        return False

    def _output_csv(self, filename, headers, row):
        if self.args.disable_output is True:
            return
        if os.path.exists(filename):
            with open(filename) as fd:
                lines = list(csv.reader(fd)) or [[]]
                if headers and len(headers) > len(lines[0]):
                    # if prior results failed the header might not be filled in yet
                    lines[0] = headers
                else:
                    headers = lines[0]
        else:
            lines = [headers]
        lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
        with open(filename, "w") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            for line in lines:
                writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))

    def run(self):
        self._print_header()
        output_filename = self.args.output_dir
        headers = [
            "Benchmarking Framework",
            "Benchamrking Module Name",
            "Case Name",
            "tag",
            "run_backward",
            "Execution Time",
        ]

        if self.args.output_json:
            perf_list = []
@@ -490,8 +528,25 @@ class BenchmarkRunner:
                )
                for _ in range(self.num_runs)
            ]

            self._print_perf_result(reported_time, test_case)

            # output results to csv
            self._output_csv(
                output_filename,
                headers,
                [
                    test_case.framework,
                    test_case.op_bench.module_name(),
                    (
                        test_case.test_config.test_name + "_BACKWARD"
                        if test_case.test_config.run_backward is True
                        else test_case.test_config.test_name
                    ),
                    test_case.test_config.tag,
                    test_case.test_config.run_backward,
                    reported_time[0],
                ],
            )
            if self.args.output_json:
                perf_list.append(
                    self._perf_result_to_dict(reported_time, test_case)
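The `_output_csv` addition above writes one row per test case under the headers built in `run()`, padding short rows with "0" so every line has the same width. As a rough illustration (not part of the diff), this is how such a result file could be read back the same way the CI check consumes it; the file path is an assumption:

# Minimal sketch: read a result CSV produced by _output_csv and pull out the
# columns that check_perf_csv.py compares. The path is illustrative only.
import pandas as pd

df = pd.read_csv("operator_benchmark_eager_float32_cpu.csv")  # assumed local copy of the CI artifact
# Column names come from BenchmarkRunner.run(): "Benchmarking Framework",
# "Benchamrking Module Name", "Case Name", "tag", "run_backward", "Execution Time".
df = df.drop_duplicates(subset=["Case Name"], keep="first")  # mirrors check_perf_csv.main()
for _, row in df.iterrows():
    print(row["Case Name"], row["Execution Time"])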
@@ -150,6 +150,17 @@ def parse_args():
        default="None",
    )

    parser.add_argument(
        "--output-dir",
        help="Choose the output directory to save the logs",
        default="benchmark_logs",
    )
    parser.add_argument(
        "--disable-output",
        help="Disable log output to csv file",
        default="False",
    )

    args, _ = parser.parse_known_args()

    if args.omp_num_threads:
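Both new flags are plain string options (no `type=` or `action=`), so `--output-dir` carries the CSV path handed to `_output_csv` and `--disable-output` expects an explicit value. A minimal sketch of that behavior, using a stand-in parser rather than the real `parse_args()`:

# Illustrative stand-in for the two add_argument calls shown above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--output-dir", help="Choose the output directory to save the logs", default="benchmark_logs")
parser.add_argument("--disable-output", help="Disable log output to csv file", default="False")

args, _ = parser.parse_known_args(["--output-dir", "/tmp/op_bench.csv"])  # example invocation
print(args.output_dir)      # "/tmp/op_bench.csv"
print(args.disable_output)  # "False" -- a string default, not a boolean switch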
benchmarks/operator_benchmark/check_perf_csv.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import argparse
import sys
import textwrap

import pandas as pd


SKIP_TEST_LISTS = [
    # https://github.com/pytorch/pytorch/issues/143852
    "channel_shuffle_batch_size4_channels_per_group64_height64_width64_groups4_channel_lastTrue",
    "batchnorm_N3136_C256_cpu_trainingTrue_cudnnFalse",
    "index_add__M256_N512_K1_dim1_cpu_dtypetorch.float32",
    "interpolate_input_size(1,3,600,400)_output_size(240,240)_channels_lastTrue_modelinear",
    "original_kernel_tensor_N1_C3_H512_W512_zero_point_dtypetorch.int32_nbits4_cpu",
    "original_kernel_tensor_N1_C3_H512_W512_zero_point_dtypetorch.int32_nbits8_cpu",
]


def get_field(csv, case: str, field: str):
    try:
        return csv.loc[csv["Case Name"] == case][field].item()
    except Exception:
        return None


def check_perf(actual_csv, expected_csv, expected_filename, threshold):
    failed = []
    improved = []
    baseline_not_found = []

    actual_csv = actual_csv[~actual_csv["Case Name"].isin(set(SKIP_TEST_LISTS))]

    for case in actual_csv["Case Name"]:
        perf = get_field(actual_csv, case, "Execution Time")
        expected_perf = get_field(expected_csv, case, "Execution Time")

        if expected_perf is None:
            status = "Baseline Not Found"
            print(f"{case:34} {status}")
            baseline_not_found.append(case)
            continue

        speed_up = expected_perf / perf

        if (1 - threshold) <= speed_up < (1 + threshold):
            status = "PASS"
            print(f"{case:34} {status}")
            continue
        elif speed_up >= 1 + threshold:
            status = "IMPROVED:"
            improved.append(case)
        else:
            status = "FAILED:"
            failed.append(case)
        print(f"{case:34} {status:9} perf={perf}, expected={expected_perf}")

    msg = ""
    if failed or improved or baseline_not_found:
        if failed:
            msg += textwrap.dedent(
                f"""
                Error: {len(failed)} models have performance status regressed:
                {" ".join(failed)}

                """
            )
        if improved:
            msg += textwrap.dedent(
                f"""
                Improvement: {len(improved)} models have performance status improved:
                {" ".join(improved)}

                """
            )

        if baseline_not_found:
            msg += textwrap.dedent(
                f"""
                Baseline Not Found: {len(baseline_not_found)} models don't have the baseline data:
                {" ".join(baseline_not_found)}

                """
            )

        msg += textwrap.dedent(
            f"""
            If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
            """
        )
    return failed or improved or baseline_not_found, msg


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--actual", type=str, required=True)
    parser.add_argument("--expected", type=str, required=True)
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="threshold to define regression/improvement",
    )
    args = parser.parse_args()

    actual = pd.read_csv(args.actual)
    actual.drop_duplicates(subset=["Case Name"], keep="first", inplace=True)
    expected = pd.read_csv(args.expected)

    failed, msg = check_perf(actual, expected, args.expected, args.threshold)
    if failed:
        print(msg)
        sys.exit(1)


if __name__ == "__main__":
    main()
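With the default `--threshold 0.5`, `check_perf` treats a case as passing when the speed-up `expected / actual` falls in `[0.5, 1.5)`; larger values are reported as improvements and smaller ones as regressions. A small worked example of that rule (numbers are illustrative only):

# Worked example of the threshold rule used in check_perf above; values are made up.
threshold = 0.5  # default from --threshold

def classify(expected: float, actual: float) -> str:
    speed_up = expected / actual
    if (1 - threshold) <= speed_up < (1 + threshold):
        return "PASS"
    return "IMPROVED" if speed_up >= 1 + threshold else "FAILED"

print(classify(expected=1.00, actual=1.20))  # PASS     (speed_up ~ 0.83, inside [0.5, 1.5))
print(classify(expected=1.00, actual=0.50))  # IMPROVED (speed_up = 2.0, >= 1.5)
print(classify(expected=1.00, actual=2.50))  # FAILED   (speed_up = 0.4, < 0.5)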
(File diff suppressed because it is too large.)
@@ -43,6 +43,10 @@ class ReplaceNaNBenchmark(op_bench.TorchBenchmarkBase):
        self.op_func = op_func
        self.set_module_name("nan_to_num")

        # To make casename unique as nan_to_num and nan_to_num_ are two different functions.
        if op_func is torch.nan_to_num_:
            self.set_module_name("nan_to_num_")

    def forward(self, input, replace_inf: bool):
        # compare inplace
        if replace_inf:
@@ -193,8 +193,8 @@ def fakeQuantizePerTensorOriginalKernel(

fake_quantize_per_tensor_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerTensorLearnableKernel),
        ("original_kernel", fakeQuantizePerTensorOriginalKernel),
        ("learnable_kernel_tensor", fakeQuantizePerTensorLearnableKernel),
        ("original_kernel_tensor", fakeQuantizePerTensorOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)
@@ -297,8 +297,8 @@ def fakeQuantizePerChannelOriginalKernel(

fake_quantize_per_channel_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerChannelLearnableKernel),
        ("original_kernel", fakeQuantizePerChannelOriginalKernel),
        ("learnable_kernel_channel", fakeQuantizePerChannelLearnableKernel),
        ("original_kernel_channel", fakeQuantizePerChannelOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)