Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Add instruction count benchmark to run on pull requests (#131475)
This PR only adds execution of the benchmarks on the current PR and prints the results; follow-up diffs will add checking out head~1, running the benchmarks there as well, and comparing the two runs. To access the results, go to the `pr_time_benchmarks` test and inspect its logs; you should see:

```
+ echo 'benchmark results on current PR: '
benchmark results on current PR:
+ cat /var/lib/jenkins/workspace/test/test-reports/pr_time_benchmarks_before.txt
update_hint_regression,instruction_count,27971461254
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131475
Approved by: https://github.com/ezyang
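The head~1 comparison itself is left to the follow-up diffs mentioned above. Purely as an illustration of that next step, here is a hedged Python sketch that reads two CSV result files of the form shown above (`name,metric,value` rows) and flags large increases; the file names and the 5% threshold are assumptions made for the sketch, not anything this PR defines.

```python
import csv


def read_results(path):
    # Each row has the form "<benchmark name>,<metric>,<value>", e.g.
    # "update_hint_regression,instruction_count,27971461254".
    with open(path, newline="") as f:
        return {(name, metric): int(value) for name, metric, value in csv.reader(f)}


# Hypothetical file names for the head~1 run and the current-PR run.
before = read_results("test-reports/pr_time_benchmarks_before.txt")
after = read_results("test-reports/pr_time_benchmarks_after.txt")

for key, new_count in after.items():
    old_count = before.get(key)
    if old_count is not None and new_count > old_count * 1.05:  # assumed 5% threshold
        print(f"possible regression in {key[0]}: {old_count} -> {new_count}")
```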
Committed by: PyTorch MergeBot
Parent: 27c44c884e
Commit: f5e704a6f2
@@ -392,7 +392,20 @@ test_inductor_cpp_wrapper_abi_compatible() {
 # .github/workflows/inductor-perf-test-nightly.yml
 DYNAMO_BENCHMARK_FLAGS=()
 
-if [[ "${TEST_CONFIG}" == *dynamo_eager* ]]; then
+pr_time_benchmarks() {
+
+  TEST_REPORTS_DIR=$(pwd)/test/test-reports
+  mkdir -p "$TEST_REPORTS_DIR"
+  source benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt" "benchmarks/dynamo/pr_time_benchmarks/benchmarks"
+  echo "benchmark results on current PR: "
+  cat "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt"
+
+}
+
+if [[ "${TEST_CONFIG}" == *pr_time_benchmarks* ]]; then
+  pr_time_benchmarks
+  exit 0
+elif [[ "${TEST_CONFIG}" == *dynamo_eager* ]]; then
   DYNAMO_BENCHMARK_FLAGS+=(--backend eager)
 elif [[ "${TEST_CONFIG}" == *aot_eager* ]]; then
   DYNAMO_BENCHMARK_FLAGS+=(--backend aot_eager)
.github/workflows/pull.yml (22 lines changed)
@@ -110,7 +110,6 @@ jobs:
           { config: "default", shard: 1, num_shards: 1 },
         ]}
 
-
   linux-jammy-py3_10-clang15-asan-build:
     name: linux-jammy-py3.10-clang15-asan
     uses: ./.github/workflows/_linux-build.yml
@@ -571,3 +570,24 @@ jobs:
       docker-image: ${{ needs.linux-focal-py3_12-clang10-experimental-split-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-focal-py3_12-clang10-experimental-split-build.outputs.test-matrix }}
       timeout-minutes: 600
+
+  linux-focal-cuda12_1-py3_10-gcc9-inductor-build:
+    name: cuda12.1-py3.10-gcc9-sm75
+    uses: ./.github/workflows/_linux-build.yml
+    with:
+      build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm75
+      docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks
+      cuda-arch-list: '7.5'
+      test-matrix: |
+        { include: [
+          { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
+        ]}
+
+  linux-focal-cuda12_1-py3_10-gcc9-inductor-test:
+    name: cuda12.1-py3.10-gcc9-sm75
+    uses: ./.github/workflows/_linux-test.yml
+    needs: linux-focal-cuda12_1-py3_10-gcc9-inductor-build
+    with:
+      build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm75
+      docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-build.outputs.docker-image }}
+      test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-build.outputs.test-matrix }}
benchmarks/__init__.py (new file, empty)
benchmarks/dynamo/pr_time_benchmarks/__init__.py (new file, empty)
benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import csv
from abc import ABC, abstractmethod

import torch._C._instruction_counter as i_counter


class BenchmarkBase(ABC):
    _instruction_count = False

    def enable_instruction_count(self):
        self._instruction_count = True
        return self

    def name(self):
        return ""

    def description(self):
        return ""

    @abstractmethod
    def prepare(self):
        pass

    @abstractmethod
    def work(self):
        pass

    def prepare_once(self):  # noqa: B027
        pass

    def count_instructions(self):
        print(f"collecting instruction count for {self.name()}")
        self.prepare_once()

        results = []
        for i in range(10):
            self.prepare()
            id = i_counter.start()
            self.work()
            count = i_counter.end(id)
            print(f"instruction count for iteration {i} is {count}")
            if i != 0:
                results.append(count)
        return min(results)

    def append_results(self, path):
        with open(path, "a", newline="") as csvfile:
            # Create a writer object
            writer = csv.writer(csvfile)
            # Write the data to the CSV file
            for entry in self.results:
                writer.writerow(entry)

    def print(self):
        for entry in self.results:
            print(f"{entry[0]},{entry[1]},{entry[2]}")

    def collect_all(self):
        self.results = []
        if self._instruction_count:
            self.results.append(
                (self.name(), "instruction_count", self.count_instructions())
            )
        return self
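For orientation, here is a minimal sketch of how a benchmark is meant to plug into `BenchmarkBase`; the class name and workload below are made up for illustration only (the real `update_hint_regression` benchmark added later in this PR follows the same pattern).

```python
import sys

import torch
from benchmarks.dynamo.pr_time_benchmarks.benchmark_base import BenchmarkBase


class CompileMatmul(BenchmarkBase):  # hypothetical benchmark, not part of this PR
    def name(self):
        return "compile_matmul"

    def prepare(self):
        torch._dynamo.reset()  # start each measured iteration from a cold dynamo state

    def work(self):
        # work() is what gets bracketed by the perf instruction counter.
        torch.compile(lambda a, b: a @ b)(torch.randn(8, 8), torch.randn(8, 8))


if __name__ == "__main__":
    # Appends a "compile_matmul,instruction_count,<count>" row to the CSV at sys.argv[1].
    CompileMatmul().enable_instruction_count().collect_all().append_results(sys.argv[1])
```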
benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh (new file, 25 lines)
@@ -0,0 +1,25 @@
#!/bin/bash
# Check if the output file argument was provided
if [ $# -eq 0 ]
then
  echo "Please provide the output file as an argument"
  return
fi

# Check if the directory of Python programs argument was provided
if [ $# -eq 1 ]
then
  echo "Please provide the directory of Python programs as an argument"
  return
fi

# Set the output file
output_file=$1
# Set the directory of Python programs
python_programs_dir=$2
# Loop through all files in the directory of Python programs
for file in $python_programs_dir/*.py
do
  # Execute the Python program and append the output to the output file
  sudo env PATH="$PATH" python $file $output_file
done
@@ -0,0 +1,46 @@
import random
import sys

from benchmarks.dynamo.pr_time_benchmarks.benchmark_base import BenchmarkBase

import torch


class Benchmark(BenchmarkBase):
    N = 20

    def name(self):
        return "update_hint_regression"

    def description(self):
        return "information at https://github.com/pytorch/pytorch/pull/129893"

    def prepare_once(self):
        torch._dynamo.config.capture_scalar_outputs = True
        random.seed(42)
        self.splits = torch.randint(10, (self.N,))
        sz = self.splits.sum().item()
        self.input = torch.randn(sz)

    def prepare(self):
        torch._dynamo.reset()

    def work(self):
        @torch.compile(fullgraph=True)
        def f(a, b):
            xs = b.tolist()
            for x in xs:
                torch._check_is_size(x)
                torch._check(x <= self.N)
            return a.split(xs)

        f(self.input, self.splits)


def main():
    result_path = sys.argv[1]
    Benchmark().enable_instruction_count().collect_all().append_results(result_path)


if __name__ == "__main__":
    main()
@@ -1,4 +1,5 @@
 """Collect instruction counts for continuous integration."""
+# mypy: ignore-errors
 import argparse
 import hashlib
 import json
@@ -1,4 +1,5 @@
 """Key enums and structs used to handle data flow within the benchmark."""
+# mypy: ignore-errors
 import dataclasses
 import enum
 import itertools as it
@@ -2,6 +2,7 @@
 
 This is mostly string manipulation, with just a bit of importlib magic.
 """
+# mypy: ignore-errors
 import importlib.abc
 import importlib.util
 import itertools as it
@@ -1,4 +1,5 @@
 """Type annotations for various benchmark objects."""
+# mypy: ignore-errors
 from typing import Any, Dict, Optional, Tuple, Union
 
 from core.api import AutoLabels, GroupedBenchmark, TimerArgs
@@ -1,3 +1,4 @@
+# mypy: ignore-errors
 import atexit
 import re
 import shutil
@@ -1,5 +1,5 @@
 """Define some common setup blocks which benchmarks can reuse."""
-
+# mypy: ignore-errors
 import enum
 
 from core.api import GroupedSetup
@@ -11,6 +11,7 @@ Parser notes:
     - To set a label for the succeeding block, add `# @YOUR_LABEL` (Python)
       or `// @YOUR_LABEL` (C++).
 """
+# mypy: ignore-errors
 
 from core.api import GroupedModules, GroupedStmts, GroupedVariants
 from core.types import FlatIntermediateDefinition
@@ -1,4 +1,5 @@
 """Run benchmarks while handling parallelism, isolation, and fault tolerance."""
+# mypy: ignore-errors
 import math
 import multiprocessing
 import subprocess
@@ -1,4 +1,5 @@
 """Handle the details of subprocess calls and retries for a given benchmark run."""
+# mypy: ignore-errors
 import dataclasses
 import json
 import os
@@ -5,6 +5,7 @@ expressive and robust components (e.g. better runner and result display
 components) in future iterations. However this allows us to excercise the
 underlying benchmark generation infrastructure in the mean time.
 """
+# mypy: ignore-errors
 import argparse
 import sys
 from typing import List
@@ -1,3 +1,4 @@
+import operator_benchmark as op_bench
 from pt import (  # noqa: F401
     add_test,
     ao_sparsifier_test,
@@ -29,8 +30,6 @@ from pt import (  # noqa: F401
     tensor_to_test,
 )
 
-import operator_benchmark as op_bench
-
 
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()
@@ -1,3 +1,4 @@
+import operator_benchmark as op_bench
 from pt import (  # noqa: F401
     qactivation_test,
     qarithmetic_test,
@@ -21,8 +22,6 @@ from pt import (  # noqa: F401
     qunary_test,
 )
 
-import operator_benchmark as op_bench
-
 
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()
@@ -1,8 +1,7 @@
 import benchmark_all_other_test  # noqa: F401
 import benchmark_all_quantized_test  # noqa: F401
-from pt import unary_test  # noqa: F401
-
 import operator_benchmark as op_bench
+from pt import unary_test  # noqa: F401
 
 
 if __name__ == "__main__":
@@ -1,6 +1,5 @@
-from pt import configs
-
 import operator_benchmark as op_bench
+from pt import configs
 
 import torch
 import torch.nn as nn
@@ -1,7 +1,6 @@
 import numpy
-from pt import configs
-
 import operator_benchmark as op_bench
+from pt import configs
 
 import torch
 
@@ -1,5 +1,4 @@
 import numpy
-
 import operator_benchmark as op_bench
 
 import torch
@@ -1,5 +1,4 @@
 import numpy
-
 import operator_benchmark as op_bench
 
 import torch
@@ -1,6 +1,5 @@
-from pt import configs
-
 import operator_benchmark as op_bench
+from pt import configs
 
 import torch
 import torch.nn as nn
@@ -1,7 +1,6 @@
 import numpy
-from pt import configs
-
 import operator_benchmark as op_bench
+from pt import configs
 
 import torch
 import torch.ao.nn.qat as nnqat
@@ -1,6 +1,5 @@
-from pt import configs
-
 import operator_benchmark as op_bench
+from pt import configs
 
 import torch
 import torch.ao.nn.quantized as nnq
@@ -1,7 +1,6 @@
 from typing import Optional
 
 import numpy as np
-
 import operator_benchmark as op_bench
 
 import torch
@@ -1,7 +1,6 @@
 import numpy
-from pt import configs
-
 import operator_benchmark as op_bench
+from pt import configs
 
 import torch
 import torch.ao.nn.quantized as nnq
@@ -1,6 +1,5 @@
-from pt import configs
-
 import operator_benchmark as op_bench
+from pt import configs
 
 import torch
 import torch.ao.nn.quantized as nnq
@@ -913,6 +913,7 @@ libtorch_python_core_sources = [
     "torch/csrc/utils/disable_torch_function.cpp",
     "torch/csrc/utils/verbose.cpp",
     "torch/csrc/cpu/Module.cpp",
+    "torch/csrc/instruction_counter/Module.cpp",
 ] + lazy_tensor_core_python_sources
 
 libtorch_python_distributed_core_sources = [
@@ -70,6 +70,7 @@
 #include <torch/csrc/functorch/init.h>
 #include <torch/csrc/fx/node.h>
 #include <torch/csrc/inductor/aoti_runner/pybind.h>
+#include <torch/csrc/instruction_counter/Module.h>
 #include <torch/csrc/jit/python/init.h>
 #include <torch/csrc/jit/python/python_ir.h>
 #include <torch/csrc/jit/python/python_tracer.h>
@@ -1698,6 +1699,7 @@ PyObject* initModule() {
 #endif
   torch::mtia::initModule(module);
   torch::cpu::initModule(module);
+  torch::instruction_counter::initModule(module);
   torch::initVerboseBindings(module);
   ASSERT_TRUE(THPStorage_init(module));
 
torch/csrc/instruction_counter/Module.cpp (new file, 86 lines)
@@ -0,0 +1,86 @@
#include <torch/csrc/instruction_counter/Module.h>
#include <torch/csrc/utils/pybind.h>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <stdexcept>

#if defined(__linux__)
#include <linux/perf_event.h>
#include <sys/ioctl.h>
#include <sys/syscall.h>
#include <unistd.h>
#endif

namespace torch::instruction_counter {

long start() {
#if !defined(__linux__)
  throw std::runtime_error("This systems seems not to be Linux");
#else

  // Construct base perf_event_attr struct
  perf_event_attr attr{};
  memset(&attr, 0, sizeof(attr));
  attr.size = sizeof(attr);
  attr.exclude_kernel = 1;
  attr.disabled = 1;
  attr.exclude_hv = 1;
  attr.sample_period = 0;
  // Enable hardware counting
  attr.type = PERF_TYPE_HARDWARE;
  attr.config = PERF_COUNT_HW_INSTRUCTIONS;

  long fd = syscall(SYS_perf_event_open, &attr, 0, -1, -1, 0);
  if (fd == -1) {
    fprintf(
        stderr,
        "Failed to open instruction count event: %s.\n",
        strerror(errno));
    return -1;
  }
  ioctl((int)fd, PERF_EVENT_IOC_RESET, 0); // Reset the counter
  ioctl((int)fd, PERF_EVENT_IOC_ENABLE, 0); // Enable the counter
  return fd;
#endif
}

uint64_t end(int fd) {
#if !defined(__linux__)
  throw std::runtime_error("This systems seems not to be Linux");
#else
  // Disable the event group
  if (ioctl(fd, PERF_EVENT_IOC_DISABLE, PERF_IOC_FLAG_GROUP) == -1) {
    fprintf(
        stderr,
        "Error disabling perf event (fd: %d): %s\n",
        fd,
        strerror(errno));
    return -1;
  }

  uint64_t total_instructions = 0;

  // Read results
  long ret_val = read(fd, &total_instructions, sizeof(total_instructions));
  if (ret_val == -1) {
    fprintf(stderr, "Error reading perf event results: %s\n", strerror(errno));
    return -1;
  }

  close(fd);
  return total_instructions;
#endif
}

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto instruction_counter = m.def_submodule(
      "_instruction_counter", "instruction_counter related pybind.");
  instruction_counter.def("start", start);
  instruction_counter.def("end", end);
}

} // namespace torch::instruction_counter
torch/csrc/instruction_counter/Module.h (new file, 8 lines)
@@ -0,0 +1,8 @@
#pragma once
#include <torch/csrc/python_headers.h>

namespace torch::instruction_counter {

void initModule(PyObject* module);

} // namespace torch::instruction_counter
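Taken together, the C++ side exposes just two calls under `torch._C._instruction_counter`. As a minimal sketch of driving them directly from Python (assuming a Linux machine where `perf_event_open` is permitted; the workload here is arbitrary and purely illustrative):

```python
import torch  # ensures the _C extension (and the new submodule) is loaded
import torch._C._instruction_counter as i_counter

fd = i_counter.start()                       # opens and enables a PERF_COUNT_HW_INSTRUCTIONS counter
total = sum(i * i for i in range(100_000))   # any workload to be measured
count = i_counter.end(fd)                    # disables the counter, reads it, and closes the fd
print(f"instructions retired: {count}")
```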