mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is to enable the operator benchmark on CPU to track op-level performance. This PR is motivated by issue https://github.com/pytorch/pytorch/issues/120982 and by the feasibility investigation in https://github.com/pytorch/pytorch/pull/127216 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143733 Approved by: https://github.com/leslie-fang-intel, https://github.com/atalman, https://github.com/huydhn, https://github.com/malfet Co-authored-by: diwei sun <diwei.sun@intel.com> Co-authored-by: chuanqiw <chuanqi.wang@intel.com>
387 lines
12 KiB
Python
387 lines
12 KiB
Python
import operator_benchmark as op_bench
|
|
|
|
import torch
|
|
import torch.ao.nn.quantized as nnq
|
|
import torch.ao.quantization as tq
|
|
import torch.nn as nn
|
|
|
|
|
|
"""Microbenchmarks for general quantization operations."""
|
|
|
|
# mode is used to show the direction of the benchmark:
# if 'Q', benchmark quantization, else dequantization.

# Short runs: one quantization and one dequantization case on a
# (3, 512, 512) quint8 tensor, listed explicitly.
quantize_configs_short_dict = dict(
    attr_names=["C", "M", "N", "dtype", "mode"],
    attrs=[[3, 512, 512, torch.quint8, mode] for mode in ("Q", "D")],
    tags=["short"],
)

# Long runs: value lists that op_bench.cross_product_configs expands below.
quantize_configs_long_dict = dict(
    C=[3, 5, 8],  # this is reused for per-channel: avoid single channel test
    M=[256, 1024],
    N=[256, 1024],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    mode=["D", "Q"],
    tags=["long"],
)

quantize_per_tensor_configs_short = op_bench.config_list(
    **quantize_configs_short_dict
)
quantize_per_tensor_configs_long = op_bench.cross_product_configs(
    **quantize_configs_long_dict
)
|
|
|
|
|
|
class QuantizePerTensorBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, mode):
        """Build a (C, M, N) float input and the module to time.

        "Q" mode times ``nnq.Quantize`` (scale 1.0, zero point 0) on the float
        tensor. "D" mode quantizes the tensor once here and times
        ``nnq.DeQuantize`` instead.
        """
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.dtype = dtype
        quantizer = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
        self.op = quantizer
        self.set_module_name("QuantizePerTensor")

        benchmark_dequant = mode == "D"
        if benchmark_dequant:
            # Quantize up front so that only dequantization is measured.
            self.input = quantizer(self.input)
            self.op = nnq.DeQuantize()
            self.set_module_name("DequantizePerTensor")

        self.inputs = dict(input=self.input)

    def forward(self, input):
        """Apply the selected (de)quantization module to ``input``."""
        return self.op(input)


op_bench.generate_pt_test(
    quantize_per_tensor_configs_short + quantize_per_tensor_configs_long,
    QuantizePerTensorBenchmark,
)
|
|
|
|
# === Per Channel quantization ===

# Per-channel runs reuse the per-tensor grids and add the quantization axis:
# short runs use axis 0 only, long runs sweep every axis of the (C, M, N)
# input.
quantize_per_channel_configs_short = op_bench.config_list(
    cross_product_configs=dict(axis=(0,)),
    **quantize_configs_short_dict,
)
quantize_per_channel_configs_long = op_bench.cross_product_configs(
    axis=(0, 1, 2),
    **quantize_configs_long_dict,
)
|
|
|
|
|
|
class QuantizePerChannelBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, axis, mode):
        """Set up a (C, M, N) input with unit per-channel qparams on ``axis``.

        Every channel along ``axis`` gets scale 1.0 and zero point 0.
        "Q" mode times ``torch.quantize_per_channel``. "D" mode quantizes the
        input once here and times a wrapper that only calls ``dequantize()``.
        """
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.op = torch.quantize_per_channel

        n_channels = (C, M, N)[axis]

        self.kwargs = dict(
            scales=torch.ones(n_channels),
            zero_points=torch.zeros(n_channels, dtype=torch.long),
            dtype=dtype,
            axis=axis,
        )

        self.set_module_name("QuantizePerChannel")

        if mode == "D":
            self.input = self.op(self.input, **self.kwargs)

            # Takes the same keywords as torch.quantize_per_channel, so
            # forward() can call either op. The qparams are ignored because
            # the input is already quantized.
            def dequantize_only(input, scales, zero_points, axis: int, dtype: int):
                return input.dequantize()

            self.op = dequantize_only
            self.set_module_name("DequantizePerChannel")

        # forward() receives its own qparam tensors, separate from self.kwargs.
        self.inputs = dict(
            input=self.input,
            scales=torch.ones(n_channels),
            zero_points=torch.zeros(n_channels, dtype=torch.long),
            axis=axis,
            dtype=dtype,
        )

    def forward(self, input, scales, zero_points, axis: int, dtype: int):
        """Run the selected per-channel (de)quantization op."""
        return self.op(
            input,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=dtype,
        )


op_bench.generate_pt_test(
    quantize_per_channel_configs_short + quantize_per_channel_configs_long,
    QuantizePerChannelBenchmark,
)
|
|
|
|
# === Fake Quantization ===
# Names of the op-list benchmarks generated further below start with
# 'learnable_kernel' or 'original_kernel',
# e.g. 'original_kernel_nbits8_cpu_N1_C1_H256_W256_zero_point_dtypetorch.int32_bwdall'

# Short runs: a single explicitly listed (N, C, H, W) case.
fake_quantize_configs_short_dict = dict(
    attr_names=["N", "C", "H", "W", "zero_point_dtype"],
    attrs=[[1, 3, 512, 512, torch.int32]],
    tags=["short"],
)

# Long runs: value lists expanded as a cross product.
fake_quantize_configs_long_dict = dict(
    N=[1],
    C=[1, 3, 8, 32],
    H=[256, 1024],
    W=[256, 1024],
    zero_point_dtype=[torch.int32],
    tags=["long"],
)

# The fake-quantize module configs are generated for both CPU and CUDA.
fake_quantize_configs_short = op_bench.config_list(
    cross_product_configs=dict(device=("cpu", "cuda")),
    **fake_quantize_configs_short_dict,
)
fake_quantize_configs_long = op_bench.cross_product_configs(
    device=("cpu", "cuda"),
    **fake_quantize_configs_long_dict,
)
|
|
|
|
|
|
class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks fake quantization with default parameters."""

    def init(self, N, C, H, W, zero_point_dtype, device):
        """Time a default-constructed ``tq.FakeQuantize`` module on ``device``.

        ``zero_point_dtype`` is part of the shared fake-quantize config grid
        and is not used by this benchmark.
        """
        sample = torch.rand(N, C, H, W).to(device)
        self.inputs = dict(input=sample)
        fake_quant = tq.FakeQuantize()
        self.op = fake_quant.to(device)
        self.set_module_name("FakeQuantize")

    def forward(self, input):
        """Apply the FakeQuantize module to ``input``."""
        return self.op(input)


op_bench.generate_pt_test(
    fake_quantize_configs_short + fake_quantize_configs_long,
    FakeQuantizeBenchmark,
)
|
|
|
|
|
|
# The op_name attribute of the op lists below describes the type of operator
# being benchmarked:
# learnable_kernel represents the C++ kernel that can backpropagate on
# scale and zero point.
# original_kernel represents the original fake quantize C++ kernel.
|
|
|
|
|
|
def fakeQuantizePerTensorLearnableKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` with the learnable per-tensor kernel.

    Thin wrapper over ``torch._fake_quantize_learnable_per_tensor_affine``
    that forwards ``scale`` and ``zero_point`` unchanged. The ``int``
    annotations are presumably for TorchScript, where unannotated arguments
    default to Tensor. TODO confirm.
    """
    fake_quantized = torch._fake_quantize_learnable_per_tensor_affine(
        input, scale, zero_point, quant_min, quant_max
    )
    return fake_quantized
|
|
|
|
|
|
def fakeQuantizePerTensorOriginalKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` with the original per-tensor kernel.

    ``scale`` and ``zero_point`` are accepted so this op shares the calling
    convention of the learnable variant, but they are NOT used. The kernel
    is always called with a Python scale of 1.0 and zero point of 0.
    """
    unit_scale = 1.0
    zero_offset = 0
    return torch.fake_quantize_per_tensor_affine(
        input, unit_scale, zero_offset, quant_min, quant_max
    )
|
|
|
|
|
|
fake_quantize_per_tensor_ops = op_bench.op_list(
    attr_names=("op_name", "op_func"),
    attrs=(
        ("learnable_kernel_tensor", fakeQuantizePerTensorLearnableKernel),
        ("original_kernel_tensor", fakeQuantizePerTensorOriginalKernel),
    ),
)

# Operator configs add a bit width (nbits) and a device on top of the shared
# fake-quantize shape grids.
fake_quantize_operator_configs_short = op_bench.config_list(
    cross_product_configs=dict(nbits=(4, 8), device=("cpu", "cuda")),
    **fake_quantize_configs_short_dict,
)
fake_quantize_operator_configs_long = op_bench.cross_product_configs(
    nbits=(4, 8),
    device=("cpu", "cuda"),
    **fake_quantize_configs_long_dict,
)

# TODO(future PR): Combine the floating-point zero_point config with the other
# configs once it is fully supported in all fakeQuant operators and devices
# (https://github.com/pytorch/pytorch/issues/61866).
# Same long-run grid, but with floating-point zero points, at 8 bits only.
fake_quantize_configs_long_dict_float_zero_point = {
    **fake_quantize_configs_long_dict,
    "zero_point_dtype": [torch.float32, torch.half],
}

fake_quantize_operator_configs_long_float_zero_point = op_bench.cross_product_configs(
    nbits=(8,),
    device=("cpu", "cuda"),
    **fake_quantize_configs_long_dict_float_zero_point,
)
|
|
|
|
|
|
class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks fake quantize per tensor operators.

    The operator under test is passed in as ``op_func`` (an entry of
    ``fake_quantize_per_tensor_ops``). It is called on an (N, C, H, W) float
    input with a scale of 1.0 and a zero point of 0.0.
    """

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        """Build the inputs and quantization range for one benchmark config.

        ``quant_min``/``quant_max`` span the unsigned ``nbits``-bit range
        ``[0, 2**nbits - 1]``. ``zero_point_dtype`` comes from the shared
        config grid and is intentionally not applied here (see below).
        """
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        # auto_set() is called once per tensor, in this order: input, scale,
        # zero_point.
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )
        self.scale = nn.Parameter(
            torch.tensor([1.0]).to(device), requires_grad=self.auto_set()
        )
        # The zero point is always a float32 Parameter.
        # - It may be asked to require grad, which only floating-point
        #   tensors support.
        # - The original per-tensor wrapper ignores it anyway.
        # zero_point_dtype is therefore deliberately not applied. A check
        # against the per-channel kernel name can never match a per-tensor op.
        self.zero_point = nn.Parameter(
            torch.tensor([0.0]).to(device), requires_grad=self.auto_set()
        )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }
        self.op_func = op_func

    def forward(self, input, scale, zero_point, quant_min: int, quant_max: int):
        """Call the per-tensor fake-quantize op under test."""
        return self.op_func(input, scale, zero_point, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)
|
|
|
|
|
|
def fakeQuantizePerChannelLearnableKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` along ``axis`` with the learnable per-channel kernel.

    Thin wrapper over ``torch._fake_quantize_learnable_per_channel_affine``
    that forwards every argument positionally and unchanged.
    """
    fake_quantized = torch._fake_quantize_learnable_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )
    return fake_quantized
|
|
|
|
|
|
def fakeQuantizePerChannelOriginalKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` along ``axis`` with the original per-channel kernel.

    Thin wrapper over ``torch.fake_quantize_per_channel_affine``. Unlike the
    per-tensor original wrapper, it does use the given ``scale`` and
    ``zero_point`` tensors.
    """
    fake_quantized = torch.fake_quantize_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )
    return fake_quantized
|
|
|
|
|
|
fake_quantize_per_channel_ops = op_bench.op_list(
    attr_names=("op_name", "op_func"),
    attrs=(
        ("learnable_kernel_channel", fakeQuantizePerChannelLearnableKernel),
        ("original_kernel_channel", fakeQuantizePerChannelOriginalKernel),
    ),
)

# Only the original per-channel kernel is paired with the floating-point
# zero-point configs.
fake_quantize_per_channel_float_zero_point_ops = op_bench.op_list(
    attr_names=("op_name", "op_func"),
    attrs=(("original_kernel", fakeQuantizePerChannelOriginalKernel),),
)
|
|
|
|
|
|
class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks fake quantize per channel operators.

    The operator under test is passed in as ``op_func`` (an entry of
    ``fake_quantize_per_channel_ops`` or
    ``fake_quantize_per_channel_float_zero_point_ops``). It quantizes an
    (N, C, H, W) float input along the channel dimension C.
    """

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        """Build the inputs and per-channel qparams for one benchmark config.

        ``quant_min``/``quant_max`` span ``[0, 2**nbits - 1]``.
        - Original kernel: scale/zero_point are plain tensors that never
          require grad, and the zero point uses ``zero_point_dtype``.
        - Learnable kernel: both are float32 Parameters whose requires_grad
          comes from ``auto_set()``.
        """
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        # Axis is chosen with respect to the number of channels: C.
        self.axis = 1
        # requires_grad must be passed to nn.Parameter itself. nn.Parameter
        # defaults to requires_grad=True and overrides the flag of the tensor
        # it wraps, so setting it on torch.rand made the input require grad
        # even in forward-only runs. auto_set() is still called for the input
        # first, then for scale/zero_point below.
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )

        if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
            self.scale = torch.ones(
                C, device=device, dtype=torch.float32, requires_grad=False
            )
            self.zero_point = torch.zeros(
                C, device=device, dtype=zero_point_dtype, requires_grad=False
            )
        else:
            self.scale = nn.Parameter(
                torch.ones(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )
            self.zero_point = nn.Parameter(
                torch.zeros(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "axis": self.axis,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }

        self.op_func = op_func

    def forward(
        self, input, scale, zero_point, axis: int, quant_min: int, quant_max: int
    ):
        """Call the per-channel fake-quantize op under test."""
        return self.op_func(input, scale, zero_point, axis, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_float_zero_point_ops,
    fake_quantize_operator_configs_long_float_zero_point,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands control to the shared op_bench runner.
    runner = op_bench.benchmark_runner
    runner.main()
|