Add Matmul recipe into x86_inductor_quantizer (#122776)
**Summary**

Add `matmul` to the quantization recipes. Note that this is not a general recipe; it is tailored to meet the accuracy criteria of specific models, and the `matmul` recipe is disabled by default.

**Test Plan**

```
python -m pytest quantization/pt2e/test_x86inductor_quantizer.py -k test_attention_block
```

Differential Revision: [D56288468](https://our.internmc.facebook.com/intern/diff/D56288468)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122776
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 1fcdea8cd6
Commit: dd440ac734
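The recipe is opt-in. As the new test below shows, a user enables it by routing `torch.matmul` to the quantizer's global configuration; a minimal sketch of that call sequence (import paths follow the ones used in the test file):

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

# Default recipe: matmul is not annotated for quantization.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)

# Opt in to the matmul recipe by reusing the global config for torch.matmul.
quantizer.set_function_type_qconfig(
    torch.matmul, quantizer.get_global_quantization_config()
)
```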
Changes to `TestHelperModules` and `TestQuantizePT2EX86Inductor` (test side):

```diff
@@ -359,21 +359,45 @@ class TestHelperModules:
             return tmp + self.bn2(self.conv2(tmp))
 
     class SelfAttnLikeModule(torch.nn.Module):
-        def __init__(self, input_dim) -> None:
+        def __init__(
+            self,
+            input_dim,
+            transpose_for_score=False,
+            num_attention_heads=None,
+            attention_head_size=None,
+        ) -> None:
             super().__init__()
             self.input_dim = input_dim
             self.q_proj = nn.Linear(input_dim, input_dim, bias=False)
             self.k_proj = nn.Linear(input_dim, input_dim, bias=False)
             self.v_proj = nn.Linear(input_dim, input_dim, bias=False)
             self.softmax = nn.Softmax(dim=-1)
+            self.transpose_for_score = transpose_for_score
+            if self.transpose_for_score:
+                assert num_attention_heads is not None
+                assert attention_head_size is not None
+                self.num_attention_heads = num_attention_heads
+                self.attention_head_size = attention_head_size
+
+        def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+            new_x_shape = x.size()[:-1] + (
+                self.num_attention_heads,
+                self.attention_head_size,
+            )
+            x = x.view(new_x_shape)
+            return x.permute(0, 2, 1, 3)
 
         def forward(self, x):
             q = self.q_proj(x)
             k = self.k_proj(x)
             v = self.v_proj(x)
-            scores = torch.bmm(q, k.transpose(1, 2)) / (self.input_dim**0.5)
+            if self.transpose_for_score:
+                q = self.transpose_for_scores(q)
+                k = self.transpose_for_scores(k)
+                v = self.transpose_for_scores(v)
+            scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
             attention = self.softmax(scores)
-            weighted = torch.bmm(attention, v)
+            weighted = torch.matmul(attention, v)
             return weighted
 
 
```
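For orientation, the updated helper can be exercised on its own. A minimal sketch mirroring the shapes used in `test_attention_block` below (hidden size 1024 = 16 heads × 64 per head); it assumes the `SelfAttnLikeModule` class from the hunk above, together with its `nn` import, is in scope:

```python
import torch

# SelfAttnLikeModule is the class defined in the hunk above
# (in the test suite it lives under TestHelperModules).
m = SelfAttnLikeModule(
    input_dim=64 * 16,          # hidden size 1024
    transpose_for_score=True,   # reshape to (batch, heads, seq, head_dim)
    num_attention_heads=16,
    attention_head_size=64,
).eval()

x = torch.randn(2, 384, 1024)   # (batch, seq_len, hidden)
with torch.no_grad():
    out = m(x)
print(out.shape)                # torch.Size([2, 16, 384, 64]); heads are not merged back
```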
```diff
@@ -1402,7 +1426,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                torch.ops.quantized_decomposed.choose_qparams.tensor,
                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                torch.ops.quantized_decomposed.dequantize_per_channel.default,
                torch.ops.aten.linear.default,
            ]
            self._test_quantizer(
```
```diff
@@ -1438,7 +1461,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                torch.ops.quantized_decomposed.choose_qparams.tensor,
                torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
                torch.ops.quantized_decomposed.dequantize_per_channel.default,
                torch.ops.aten.linear.default,
            ]
            self._test_quantizer(
```
```diff
@@ -1551,3 +1573,72 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 node_occurrence,
                 node_list,
             )
+
+    @skipIfNoX86
+    def test_attention_block(self):
+        """
+        Test pattern of Attention like Block with X86InductorQuantizer.
+        """
+        for annotate_matmul in [False, True]:
+            with override_quantized_engine("x86"), torch.no_grad():
+                m = TestHelperModules.SelfAttnLikeModule(
+                    input_dim=64 * 16,
+                    transpose_for_score=True,
+                    num_attention_heads=16,
+                    attention_head_size=64,
+                ).eval()
+                example_inputs = (torch.randn(2, 384, 1024),)
+
+                m(*example_inputs)
+
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+
+                if annotate_matmul:
+                    quantizer.set_function_type_qconfig(
+                        torch.matmul, quantizer.get_global_quantization_config()
+                    )
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 5
+                    if annotate_matmul
+                    else 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7
+                    if annotate_matmul
+                    else 3,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+                }
+                if annotate_matmul:
+                    node_list = [
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.aten.linear.default,
+                        torch.ops.aten.view.default,
+                        torch.ops.aten.permute.default,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.aten.matmul.default,
+                        torch.ops.aten.div.Tensor,
+                        torch.ops.aten.softmax.int,
+                    ]
+                else:
+                    node_list = [
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.aten.linear.default,
+                        torch.ops.aten.view.default,
+                        torch.ops.aten.permute.default,
+                        torch.ops.aten.matmul.default,
+                        torch.ops.aten.div.Tensor,
+                        torch.ops.aten.softmax.int,
+                    ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
```
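`self._test_quantizer` wraps the standard PT2E export/prepare/convert flow and then checks `node_occurrence` and `node_list` against the converted graph. A hedged sketch of that flow outside the test harness, assuming the export entry point of the PyTorch version this PR targets (`capture_pre_autograd_graph`; newer releases use `torch.export.export_for_training` instead) and reusing `m` from the earlier sketch:

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

example_inputs = (torch.randn(2, 384, 1024),)

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
# Enable the matmul recipe (the annotate_matmul=True branch of the test).
quantizer.set_function_type_qconfig(
    torch.matmul, quantizer.get_global_quantization_config()
)

exported = capture_pre_autograd_graph(m, example_inputs)  # ATen-level FX graph
prepared = prepare_pt2e(exported, quantizer)              # insert observers per annotation
prepared(*example_inputs)                                 # calibration pass
converted = convert_pt2e(prepared)                        # materialize quantize/dequantize nodes
```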
Changes to `X86InductorQuantizer` (quantizer side):

```diff
@@ -82,7 +82,9 @@ default_quantizable_ops = propagation_quantizable_ops | {
 
 # A superset of default_quantizable_ops includes operators support the int8 data type
 # but not enabled by default recipe of X86InductorQuantizer.
-quantizable_ops = default_quantizable_ops
+quantizable_ops = default_quantizable_ops | {
+    torch.ops.aten.matmul.default,
+}
 
 QUANT_ANNOTATION_KEY = "quantization_annotation"
 
```
```diff
@@ -110,6 +112,12 @@ def _map_module_function_to_aten_operator_type():
             ],
             torch.ops.aten.flatten.using_ints,
         ),
+        (
+            [
+                torch.matmul,
+            ],
+            torch.ops.aten.matmul.default,
+        ),
     )
     for map_item in map_list:
         module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[call-overload]
```
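The new entry is what lets `set_function_type_qconfig(torch.matmul, ...)` resolve the Python callable to the ATen operator that the annotation pass scans for. A small hedged check, assuming the private helper keeps its current name and returns the mapping it builds:

```python
import torch
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    _map_module_function_to_aten_operator_type,
)

mapping = _map_module_function_to_aten_operator_type()
# torch.matmul now maps onto the op that _annotate_matmul looks for.
assert mapping[torch.matmul] is torch.ops.aten.matmul.default
```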
```diff
@@ -310,6 +318,14 @@ class X86InductorQuantizer(Quantizer):
         self.global_config = quantization_config
         return self
 
+    def get_global_quantization_config(self):
+        if not isinstance(self.global_config, QuantizationConfig):
+            warnings.warn(
+                "The global_config for X86InductorQuantizer is currently invalid. \
+                Please ensure that you use set_global to establish the global quantization configuration."
+            )
+        return self.global_config
+
     def set_function_type_qconfig(
         self,
         function_type: Callable,
```
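The new accessor is defensive: if `set_global` has not been called, `global_config` is not a `QuantizationConfig`, so the getter warns instead of raising. A hedged illustration, assuming a freshly constructed quantizer starts without a global config:

```python
import warnings
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

quantizer = X86InductorQuantizer()  # set_global not called yet

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = quantizer.get_global_quantization_config()

# The getter still returns whatever global_config holds, but flags the misuse.
assert any("set_global" in str(w.message) for w in caught)
```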
```diff
@@ -499,6 +515,7 @@ class X86InductorQuantizer(Quantizer):
         # Step1: Recipe of fusion patterns like conv/linear.
         self._annotate_conv2d_fusion_pattern(model)
         self._annotate_linear_fusion_pattern(model)
+        self._annotate_matmul(model)
 
         # Step2: Recipe to propagate annotation for patterns beside conv/linear.
         # Go through all the nodes from start to end.
```
```diff
@@ -752,6 +769,24 @@ class X86InductorQuantizer(Quantizer):
         self._annotate_linear_unary(model, config)
         self._annotate_linear(model, config)
 
+    def _annotate_matmul(self, model: torch.fx.GraphModule):
+        if config := self._get_aten_operator_qconfig(torch.ops.aten.matmul.default):
+            for node in model.graph.nodes:
+                if node.target == torch.ops.aten.matmul.default and not _is_annotated(
+                    [node]
+                ):
+                    input_qspec_map = {}
+                    matmul_node = node
+                    for input_node in matmul_node.args:
+                        input_qspec_map[input_node] = get_input_act_qspec(config)
+                    matmul_node.meta[
+                        QUANT_ANNOTATION_KEY
+                    ] = _X86InductorQuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=True,
+                    )
+
     def _annotate_conv2d_binary_unary(
         self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None:
```