mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126 Approved by: https://github.com/kit1980 ghstack dependencies: #127122, #127123, #127124, #127125
146 lines
4.2 KiB
Python
146 lines
4.2 KiB
Python
import operator_benchmark as op_bench
|
|
import torch
|
|
import torch.ao.quantization.observer as obs
|
|
|
|
# "Short" observer cases, listed explicitly: each entry of "attrs" supplies one
# value per name in "attr_names" (C, M, N, dtype, device).
qobserver_short_configs_dict = {
    "attr_names": ("C", "M", "N", "dtype", "device"),
    "attrs": (
        (3, 512, 512, torch.quint8, "cpu"),
        (3, 512, 512, torch.quint8, "cuda"),
    ),
    "tags": ("short",),
}

# "Short" histogram-observer cases: same attribute layout, a single CPU entry.
q_hist_observer_short_configs_dict = {
    "attr_names": ("C", "M", "N", "dtype", "device"),
    "attrs": ((3, 512, 512, torch.quint8, "cpu"),),
    "tags": ("short",),
}

# "Long" observer cases: a tuple of candidate values per attribute, covering
# both "cpu" and "cuda".
qobserver_long_configs_dict = {
    "C": (32, 64),
    "M": (256, 1024),
    "N": (256, 1024),
    "device": ("cpu", "cuda"),
    "dtype": (torch.quint8,),  # dtype doesn't change the timing, keep the same
    "tags": ("long",),
}

# "Long" histogram-observer cases: smaller channel counts, CPU only.
q_hist_observer_long_configs_dict = {
    "C": (1, 3, 8),
    "M": (256, 1024),
    "N": (256, 1024),
    "device": ("cpu",),
    "dtype": (torch.quint8,),  # dtype doesn't change the timing, keep the same
    "tags": ("long",),
}
|
|
|
|
|
|
# Benchmark configurations. Every "short" set passes an explicit attrs dict to
# op_bench.config_list together with a "qscheme" entry in
# cross_product_configs; every "long" set passes per-attribute value tuples
# (plus qscheme) to op_bench.cross_product_configs.
# NOTE(review): the short and long results are joined with `+` where the tests
# are generated, so both helpers presumably return lists - confirm in op_bench.

# Per-tensor observers: affine and symmetric per-tensor qschemes.
qobserver_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={
        "qscheme": (torch.per_tensor_affine, torch.per_tensor_symmetric)
    },
    **qobserver_short_configs_dict,
)

qobserver_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    **qobserver_long_configs_dict,
)

# Per-channel observers: affine and symmetric per-channel qschemes.
qobserver_per_channel_configs_short = op_bench.config_list(
    cross_product_configs={
        "qscheme": (torch.per_channel_affine, torch.per_channel_symmetric)
    },
    **qobserver_short_configs_dict,
)

qobserver_per_channel_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_channel_affine, torch.per_channel_symmetric),
    **qobserver_long_configs_dict,
)

# Histogram observers: per-tensor qschemes over the histogram-specific dicts.
q_hist_observer_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={
        "qscheme": (torch.per_tensor_affine, torch.per_tensor_symmetric)
    },
    **q_hist_observer_short_configs_dict,
)

q_hist_observer_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    **q_hist_observer_long_configs_dict,
)
|
|
|
|
|
|
# Operator tables: each row pairs a display name ("op_name") with the observer
# class to instantiate ("op_func").

# Observers benchmarked with the per-tensor configurations.
qobserver_per_tensor_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["MinMaxObserver", obs.MinMaxObserver],
        ["MovingAverageMinMaxObserver", obs.MovingAverageMinMaxObserver],
    ],
)

# Observers benchmarked with the per-channel configurations.
qobserver_per_channel_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["PerChannelMinMaxObserver", obs.PerChannelMinMaxObserver],
        [
            "MovingAveragePerChannelMinMaxObserver",
            obs.MovingAveragePerChannelMinMaxObserver,
        ],
    ],
)

# Histogram observer entries. Both rows use obs.HistogramObserver; only the
# reported op_name differs.
q_hist_observer_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["HistogramObserver", obs.HistogramObserver],
        ["HistogramObserverCalculateQparams", obs.HistogramObserver],
    ],
)
|
|
|
|
|
|
class QObserverBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark that feeds a tensor to an observer and computes its qparams.

    Each ``forward`` call runs the observer on the input and then calls
    ``calculate_qparams()`` on it.
    """

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Create the random input and the observer under test.

        Args:
            C, M, N: sizes of the three dimensions of the random input tensor.
            dtype: passed to the observer constructor as ``dtype``.
            qscheme: passed to the observer constructor as ``qscheme``.
            op_func: observer class, called as ``op_func(dtype=..., qscheme=...)``.
            device: device for both the input tensor and the observer.
        """
        self.inputs = {"f_input": torch.rand(C, M, N, device=device)}
        self.op_func = op_func(dtype=dtype, qscheme=qscheme).to(device)

    def forward(self, f_input):
        """Run the observer on ``f_input`` and return ``calculate_qparams()``."""
        self.op_func(f_input)
        return self.op_func.calculate_qparams()
|
|
|
|
|
|
class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase):
    """Benchmark of ``calculate_qparams()`` on an observer fed during setup.

    The observer is run on the input once in ``init``; ``forward`` takes no
    arguments and only calls ``calculate_qparams()``.
    """

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Build the observer and run it once on a random input.

        Args:
            C, M, N: sizes of the three dimensions of the random input tensor.
            dtype: passed to the observer constructor as ``dtype``.
            qscheme: passed to the observer constructor as ``qscheme``.
            op_func: observer class, called as ``op_func(dtype=..., qscheme=...)``.
            device: device for both the input tensor and the observer.
        """
        self.f_input = torch.rand(C, M, N, device=device)
        self.q_observer = op_func(dtype=dtype, qscheme=qscheme).to(device)
        # Observe the input up front so forward() does not have to.
        self.q_observer(self.f_input)
        # forward() takes no arguments, hence no benchmark inputs.
        self.inputs = {}

    def forward(self):
        """Return ``calculate_qparams()`` of the already-fed observer."""
        return self.q_observer.calculate_qparams()
|
|
|
|
|
|
# Register the benchmarks: each call combines an operator table, the
# concatenated short + long configurations, and the benchmark class to use.

# Per-tensor observers: observe the input, then calculate qparams.
op_bench.generate_pt_tests_from_op_list(
    qobserver_per_tensor_list,
    qobserver_per_tensor_configs_short + qobserver_per_tensor_configs_long,
    QObserverBenchmark,
)

# Per-channel observers: same benchmark class, per-channel qschemes.
op_bench.generate_pt_tests_from_op_list(
    qobserver_per_channel_list,
    qobserver_per_channel_configs_short + qobserver_per_channel_configs_long,
    QObserverBenchmark,
)

# Histogram observers use the benchmark class whose forward() only calls
# calculate_qparams().
op_bench.generate_pt_tests_from_op_list(
    q_hist_observer_list,
    q_hist_observer_per_tensor_configs_short + q_hist_observer_per_tensor_configs_long,
    QObserverBenchmarkCalculateQparams,
)
|
|
|
|
|
|
# Script entry point: hand control to the operator-benchmark runner.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
|