"""Microbenchmarks for general quantization operations."""

import operator_benchmark as op_bench

import torch
import torch.ao.nn.quantized as nnq
import torch.ao.quantization as tq
import torch.nn as nn

# 'mode' selects the direction of the benchmark:
# 'Q' benchmarks quantization, 'D' benchmarks dequantization.
quantize_configs_short_dict = {
"attr_names": ["C", "M", "N", "dtype", "mode"],
"attrs": [
[3, 512, 512, torch.quint8, "Q"],
[3, 512, 512, torch.quint8, "D"],
],
"tags": ["short"],
}
quantize_configs_long_dict = {
"C": [3, 5, 8], # this is reused for per-channel: avoid single channel test
"M": [256, 1024],
"N": [256, 1024],
"dtype": [torch.quint8, torch.qint8, torch.qint32],
"mode": ["D", "Q"],
"tags": ["long"],
}
quantize_per_tensor_configs_short = op_bench.config_list(**quantize_configs_short_dict)
quantize_per_tensor_configs_long = op_bench.cross_product_configs(
**quantize_configs_long_dict
)
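
# config_list materializes the explicit attribute rows listed under 'attrs',
# while cross_product_configs takes the Cartesian product of its keyword
# values: here 3 (C) * 2 (M) * 2 (N) * 3 (dtype) * 2 (mode) = 72 long configs.
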
class QuantizePerTensorBenchmark(op_bench.TorchBenchmarkBase):
r"""Benchmarks both quantization and dequantization."""
def init(self, C, M, N, dtype, mode):
assert mode in ("Q", "D")
self.input = torch.rand(C, M, N)
self.dtype = dtype
self.op = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
self.set_module_name("QuantizePerTensor")
if mode == "D":
self.input = self.op(self.input)
self.op = nnq.DeQuantize()
self.set_module_name("DequantizePerTensor")
self.inputs = {"input": self.input}
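        # Throughout this file the keys of self.inputs mirror forward's
        # parameter names; the op_bench harness uses this dict to supply
        # forward's arguments.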
def forward(self, input):
return self.op(input)


op_bench.generate_pt_test(
quantize_per_tensor_configs_short + quantize_per_tensor_configs_long,
QuantizePerTensorBenchmark,
)
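
# The generate_pt_test call above registers one benchmark per config; the
# reported names come from set_module_name (QuantizePerTensor or
# DequantizePerTensor, depending on mode).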


# === Per Channel quantization ===
quantize_per_channel_configs_short = op_bench.config_list(
cross_product_configs={"axis": (0,)}, **quantize_configs_short_dict
)
quantize_per_channel_configs_long = op_bench.cross_product_configs(
axis=(0, 1, 2), **quantize_configs_long_dict
)


class QuantizePerChannelBenchmark(op_bench.TorchBenchmarkBase):
r"""Benchmarks both quantization and dequantization."""
def init(self, C, M, N, dtype, axis, mode):
assert mode in ("Q", "D")
self.input = torch.rand(C, M, N)
self.op = torch.quantize_per_channel
channel_len = (C, M, N)[axis]
self.kwargs = {
"scales": torch.tensor([1.0] * channel_len),
"zero_points": torch.tensor([0] * channel_len),
"dtype": dtype,
"axis": axis,
}
self.set_module_name("QuantizePerChannel")
if mode == "D":
self.input = self.op(self.input, **self.kwargs)
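            # dequant ignores everything but input; the extra parameters keep
            # its signature aligned with torch.quantize_per_channel so that
            # forward and self.inputs work for both the 'Q' and 'D' modes.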
def dequant(input, scales, zero_points, axis: int, dtype: int):
return input.dequantize()
self.op = dequant
self.set_module_name("DequantizePerChannel")
self.inputs = {
"input": self.input,
"scales": torch.tensor([1.0] * channel_len),
"zero_points": torch.tensor([0] * channel_len),
"axis": axis,
"dtype": dtype,
}
def forward(self, input, scales, zero_points, axis: int, dtype: int):
return self.op(
input, scales=scales, zero_points=zero_points, axis=axis, dtype=dtype
)


op_bench.generate_pt_test(
quantize_per_channel_configs_short + quantize_per_channel_configs_long,
QuantizePerChannelBenchmark,
)


# === Fake Quantization ===
# Generated benchmark names start with 'learnable_kernel' or 'original_kernel',
# e.g. 'original_kernel_nbits8_cpu_N1_C1_H256_W256_zero_point_dtypetorch.int32_bwdall'
fake_quantize_configs_short_dict = {
"attr_names": ["N", "C", "H", "W", "zero_point_dtype"],
"attrs": [
[1, 3, 512, 512, torch.int32],
],
"tags": ["short"],
}
fake_quantize_configs_long_dict = {
"N": [1],
"C": [1, 3, 8, 32],
"H": [256, 1024],
"W": [256, 1024],
"zero_point_dtype": [torch.int32],
"tags": ["long"],
}
fake_quantize_configs_short = op_bench.config_list(
cross_product_configs={
"device": ("cpu", "cuda"),
},
**fake_quantize_configs_short_dict,
)
fake_quantize_configs_long = op_bench.cross_product_configs(
device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)


class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
r"""Benchmarks fake quantization with default parameters."""
def init(self, N, C, H, W, zero_point_dtype, device):
self.inputs = {"input": torch.rand(N, C, H, W).to(device)}
self.op = tq.FakeQuantize().to(device)
self.set_module_name("FakeQuantize")
def forward(self, input):
return self.op(input)


op_bench.generate_pt_test(
fake_quantize_configs_short + fake_quantize_configs_long, FakeQuantizeBenchmark
)

# op_type describes the kind of kernel used in benchmarking:
# learnable_kernel is the C++ kernel that can backpropagate on
# scale and zero point.
# original_kernel is the original fake quantize C++ kernel.
def fakeQuantizePerTensorLearnableKernel(
def fakeQuantizePerTensorLearnableKernel(
input, scale, zero_point, quant_min: int, quant_max: int
):
return torch._fake_quantize_learnable_per_tensor_affine(
input, scale, zero_point, quant_min, quant_max
)


def fakeQuantizePerTensorOriginalKernel(
input, scale, zero_point, quant_min: int, quant_max: int
):
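    # The original kernel is exercised with fixed scalar qparams (scale=1.0,
    # zero_point=0); the tensor arguments are accepted only so this wrapper
    # matches the learnable kernel's signature.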
return torch.fake_quantize_per_tensor_affine(input, 1.0, 0, quant_min, quant_max)


fake_quantize_per_tensor_ops = op_bench.op_list(
attrs=(
("learnable_kernel_tensor", fakeQuantizePerTensorLearnableKernel),
("original_kernel_tensor", fakeQuantizePerTensorOriginalKernel),
),
attr_names=("op_name", "op_func"),
)
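
# Each (op_name, op_func) pair above is crossed with the configs below by the
# generate_*_from_op_list calls, producing names like the examples at the top
# of this section.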
fake_quantize_operator_configs_short = op_bench.config_list(
cross_product_configs={
"nbits": (4, 8),
"device": ("cpu", "cuda"),
},
**fake_quantize_configs_short_dict,
)
fake_quantize_operator_configs_long = op_bench.cross_product_configs(
nbits=(4, 8), device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)

# TODO(future PR): combine the floating point zero_point configs with the other
# configs once float zero_point is fully supported in all fakeQuant operators
# and on all devices; see https://github.com/pytorch/pytorch/issues/61866.
fake_quantize_configs_long_dict_float_zero_point = (
fake_quantize_configs_long_dict.copy()
)
fake_quantize_configs_long_dict_float_zero_point["zero_point_dtype"] = [
torch.float32,
torch.half,
]
fake_quantize_operator_configs_long_float_zero_point = op_bench.cross_product_configs(
nbits=(8,),
device=("cpu", "cuda"),
**fake_quantize_configs_long_dict_float_zero_point,
)


class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks the learnable and original fake quantize per tensor kernels."""
def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
self.quant_min = 0
self.quant_max = 2**nbits - 1
self.quant_range = 2**nbits
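        # e.g. nbits=8 gives quant_min=0, quant_max=255 (256 quantization levels)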
self.input = nn.Parameter(
torch.rand(N, C, H, W, dtype=torch.float, device=device),
requires_grad=self.auto_set(),
)
self.scale = nn.Parameter(
torch.tensor([1.0]).to(device), requires_grad=self.auto_set()
)
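        # Note: this condition never matches in the per-tensor benchmark (the
        # ops registered here are the per-tensor kernels), so the else branch
        # always runs and zero_point_dtype is effectively unused in this class.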
if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
self.zero_point = nn.Parameter(
torch.tensor([0.0]).to(device).to(zero_point_dtype),
requires_grad=self.auto_set(),
)
else:
self.zero_point = nn.Parameter(
torch.tensor([0.0]).to(device), requires_grad=self.auto_set()
)
self.inputs = {
"input": self.input,
"scale": self.scale,
"zero_point": self.zero_point,
"quant_min": self.quant_min,
"quant_max": self.quant_max,
}
self.op_func = op_func
def forward(self, input, scale, zero_point, quant_min: int, quant_max: int):
return self.op_func(input, scale, zero_point, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
fake_quantize_per_tensor_ops,
fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
FakeQuantizePerTensorBaseOpBenchmark,
)
op_bench.generate_pt_gradient_tests_from_op_list(
fake_quantize_per_tensor_ops,
fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
FakeQuantizePerTensorBaseOpBenchmark,
)
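
# The gradient generator above additionally registers backward-pass variants
# of each test (the '_bwdall' suffix in the example name near the top of this
# section).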


def fakeQuantizePerChannelLearnableKernel(
input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
return torch._fake_quantize_learnable_per_channel_affine(
input, scale, zero_point, axis, quant_min, quant_max
)


def fakeQuantizePerChannelOriginalKernel(
input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
return torch.fake_quantize_per_channel_affine(
input, scale, zero_point, axis, quant_min, quant_max
)


fake_quantize_per_channel_ops = op_bench.op_list(
attrs=(
("learnable_kernel_channel", fakeQuantizePerChannelLearnableKernel),
("original_kernel_channel", fakeQuantizePerChannelOriginalKernel),
),
attr_names=("op_name", "op_func"),
)
fake_quantize_per_channel_float_zero_point_ops = op_bench.op_list(
attrs=(("original_kernel", fakeQuantizePerChannelOriginalKernel),),
attr_names=("op_name", "op_func"),
)


class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks the learnable and original fake quantize per channel kernels."""
def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
self.quant_min = 0
self.quant_max = 2**nbits - 1
self.quant_range = 2**nbits
        # Axis 1 is the channel dimension C of the (N, C, H, W) input.
        self.axis = 1
self.input = nn.Parameter(
torch.rand(
N,
C,
H,
W,
dtype=torch.float,
device=device,
requires_grad=self.auto_set(),
)
)
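        # Note: nn.Parameter defaults to requires_grad=True, so the auto_set()
        # passed to torch.rand above is overridden (the per-tensor benchmark
        # passes it to nn.Parameter instead).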
if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
self.scale = torch.ones(
C, device=device, dtype=torch.float32, requires_grad=False
)
self.zero_point = torch.zeros(
C, device=device, dtype=zero_point_dtype, requires_grad=False
)
else:
self.scale = nn.Parameter(
torch.ones(C, device=device, dtype=torch.float32),
requires_grad=self.auto_set(),
)
self.zero_point = nn.Parameter(
torch.zeros(C, device=device, dtype=torch.float32),
requires_grad=self.auto_set(),
)
self.inputs = {
"input": self.input,
"scale": self.scale,
"zero_point": self.zero_point,
"axis": self.axis,
"quant_min": self.quant_min,
"quant_max": self.quant_max,
}
self.op_func = op_func
def forward(
self, input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
return self.op_func(input, scale, zero_point, axis, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
fake_quantize_per_channel_ops,
fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
FakeQuantizePerChannelOpBenchmark,
)
op_bench.generate_pt_tests_from_op_list(
fake_quantize_per_channel_float_zero_point_ops,
fake_quantize_operator_configs_long_float_zero_point,
FakeQuantizePerChannelOpBenchmark,
)
op_bench.generate_pt_gradient_tests_from_op_list(
fake_quantize_per_channel_ops,
fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
FakeQuantizePerChannelOpBenchmark,
)
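
# A typical direct invocation, following the operator_benchmark README's
# pattern (run from benchmarks/operator_benchmark):
#   python -m pt.quantization_test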
if __name__ == "__main__":
op_bench.benchmark_runner.main()