pytorch/benchmarks/operator_benchmark/pt/qobserver_test.py
Xuehai Pan 7763c83af6 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and produces the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
2024-05-27 04:22:18 +00:00
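With the `known` lists actually applied, `usort` regroups the import blocks. A minimal sketch of the effect at the top of this file, assuming `torch` is configured as first-party in `pyproject.toml` and the local `operator_benchmark` module is not covered by the `known` lists:

    import operator_benchmark as op_bench

    import torch
    import torch.ao.quantization.observer as obs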


import operator_benchmark as op_bench

import torch
import torch.ao.quantization.observer as obs
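
# Benchmark input configurations: the "short" dicts list a few fixed
# (C, M, N) shapes explicitly, while the "long" dicts provide value ranges
# that are expanded into all combinations via the cross products below.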
qobserver_short_configs_dict = {
    "attr_names": ("C", "M", "N", "dtype", "device"),
    "attrs": (
        (3, 512, 512, torch.quint8, "cpu"),
        (3, 512, 512, torch.quint8, "cuda"),
    ),
    "tags": ("short",),
}

q_hist_observer_short_configs_dict = {
    "attr_names": ("C", "M", "N", "dtype", "device"),
    "attrs": ((3, 512, 512, torch.quint8, "cpu"),),
    "tags": ("short",),
}

qobserver_long_configs_dict = {
    "C": (32, 64),
    "M": (256, 1024),
    "N": (256, 1024),
    "device": ("cpu", "cuda"),
    "dtype": (torch.quint8,),  # dtype doesn't change the timing, keep the same
    "tags": ("long",),
}

q_hist_observer_long_configs_dict = {
    "C": (1, 3, 8),
    "M": (256, 1024),
    "N": (256, 1024),
    "device": ("cpu",),
    "dtype": (torch.quint8,),  # dtype doesn't change the timing, keep the same
    "tags": ("long",),
}

qobserver_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={
        "qscheme": (torch.per_tensor_affine, torch.per_tensor_symmetric)
    },
    **qobserver_short_configs_dict,
)

qobserver_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    **qobserver_long_configs_dict,
)

qobserver_per_channel_configs_short = op_bench.config_list(
    cross_product_configs={
        "qscheme": (torch.per_channel_affine, torch.per_channel_symmetric)
    },
    **qobserver_short_configs_dict,
)

qobserver_per_channel_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_channel_affine, torch.per_channel_symmetric),
    **qobserver_long_configs_dict,
)

q_hist_observer_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={
        "qscheme": (torch.per_tensor_affine, torch.per_tensor_symmetric)
    },
    **q_hist_observer_short_configs_dict,
)

q_hist_observer_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    **q_hist_observer_long_configs_dict,
)
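
# Observer implementations under test: each op_list entry pairs a display name
# ("op_name") with the observer class ("op_func") that the benchmark
# instantiates in init().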
qobserver_per_tensor_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["MinMaxObserver", obs.MinMaxObserver],
        ["MovingAverageMinMaxObserver", obs.MovingAverageMinMaxObserver],
    ],
)

qobserver_per_channel_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["PerChannelMinMaxObserver", obs.PerChannelMinMaxObserver],
        [
            "MovingAveragePerChannelMinMaxObserver",
            obs.MovingAveragePerChannelMinMaxObserver,
        ],
    ],
)

q_hist_observer_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["HistogramObserver", obs.HistogramObserver],
        ["HistogramObserverCalculateQparams", obs.HistogramObserver],
    ],
)

class QObserverBenchmark(op_bench.TorchBenchmarkBase):
    """Times a single observer forward pass followed by calculate_qparams()."""

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        self.inputs = {"f_input": torch.rand(C, M, N, device=device)}
        self.op_func = op_func(dtype=dtype, qscheme=qscheme).to(device)

    def forward(self, f_input):
        self.op_func(f_input)
        return self.op_func.calculate_qparams()


class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase):
    """Times calculate_qparams() alone; the observer records statistics once in init()."""

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        self.f_input = torch.rand(C, M, N, device=device)
        self.q_observer = op_func(dtype=dtype, qscheme=qscheme).to(device)
        self.q_observer(self.f_input)
        self.inputs = {}

    def forward(self):
        return self.q_observer.calculate_qparams()

op_bench.generate_pt_tests_from_op_list(
    qobserver_per_tensor_list,
    qobserver_per_tensor_configs_short + qobserver_per_tensor_configs_long,
    QObserverBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    qobserver_per_channel_list,
    qobserver_per_channel_configs_short + qobserver_per_channel_configs_long,
    QObserverBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    q_hist_observer_list,
    q_hist_observer_per_tensor_configs_short + q_hist_observer_per_tensor_configs_long,
    QObserverBenchmarkCalculateQparams,
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
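
As a usage sketch, assuming the standard operator_benchmark runner flags (the exact tag-filter option name may vary between versions): running `python -m pt.qobserver_test` from benchmarks/operator_benchmark executes all generated tests, and a filter such as `--tag-filter short` would restrict the run to the short configs.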