Add Matmul recipe into x86_inductor_quantizer (#122776)

**Summary**
Add `matmul` to the quantization recipes. Note that this is not a general recipe; it is tailored to meet accuracy criteria for specific models, so the `matmul` recipe is disabled by default.
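
Because the recipe is off by default, a caller has to opt in per function type rather than rely on the global recipe. A minimal sketch of the intended usage, mirroring the test added below (the `torch.ao.quantization.quantizer.x86_inductor_quantizer` import path is assumed):
```
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

# Build a quantizer with the default x86 Inductor recipe; matmul is not
# annotated by this global config on its own.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)

# Explicitly opt in to the matmul recipe by reusing the global qconfig
# for the torch.matmul function type.
quantizer.set_function_type_qconfig(
    torch.matmul, quantizer.get_global_quantization_config()
)
```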

**Test Plan**
```
python -m pytest quantization/pt2e/test_x86inductor_quantizer.py -k test_attention_block
```

Differential Revision: [D56288468](https://our.internmc.facebook.com/intern/diff/D56288468)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122776
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
leslie-fang-intel
2024-04-17 13:27:32 +08:00
committed by PyTorch MergeBot
parent 1fcdea8cd6
commit dd440ac734
2 changed files with 132 additions and 6 deletions

@@ -359,21 +359,45 @@ class TestHelperModules:
return tmp + self.bn2(self.conv2(tmp))
class SelfAttnLikeModule(torch.nn.Module):
def __init__(self, input_dim) -> None:
def __init__(
self,
input_dim,
transpose_for_score=False,
num_attention_heads=None,
attention_head_size=None,
) -> None:
super().__init__()
self.input_dim = input_dim
self.q_proj = nn.Linear(input_dim, input_dim, bias=False)
self.k_proj = nn.Linear(input_dim, input_dim, bias=False)
self.v_proj = nn.Linear(input_dim, input_dim, bias=False)
self.softmax = nn.Softmax(dim=-1)
self.transpose_for_score = transpose_for_score
if self.transpose_for_score:
assert num_attention_heads is not None
assert attention_head_size is not None
self.num_attention_heads = num_attention_heads
self.attention_head_size = attention_head_size
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, x):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
scores = torch.bmm(q, k.transpose(1, 2)) / (self.input_dim**0.5)
if self.transpose_for_score:
q = self.transpose_for_scores(q)
k = self.transpose_for_scores(k)
v = self.transpose_for_scores(v)
scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
attention = self.softmax(scores)
weighted = torch.bmm(attention, v)
weighted = torch.matmul(attention, v)
return weighted
@@ -1402,7 +1426,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
torch.ops.quantized_decomposed.choose_qparams.tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.aten.linear.default,
]
self._test_quantizer(
@@ -1438,7 +1461,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
torch.ops.quantized_decomposed.choose_qparams.tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.aten.linear.default,
]
self._test_quantizer(
@@ -1551,3 +1573,72 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence,
node_list,
)
@skipIfNoX86
def test_attention_block(self):
"""
Test pattern of Attention like Block with X86InductorQuantizer.
"""
for annotate_matmul in [False, True]:
with override_quantized_engine("x86"), torch.no_grad():
m = TestHelperModules.SelfAttnLikeModule(
input_dim=64 * 16,
transpose_for_score=True,
num_attention_heads=16,
attention_head_size=64,
).eval()
example_inputs = (torch.randn(2, 384, 1024),)
m(*example_inputs)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
if annotate_matmul:
quantizer.set_function_type_qconfig(
torch.matmul, quantizer.get_global_quantization_config()
)
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5
if annotate_matmul
else 1,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7
if annotate_matmul
else 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
}
if annotate_matmul:
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.view.default,
torch.ops.aten.permute.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.matmul.default,
torch.ops.aten.div.Tensor,
torch.ops.aten.softmax.int,
]
else:
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.linear.default,
torch.ops.aten.view.default,
torch.ops.aten.permute.default,
torch.ops.aten.matmul.default,
torch.ops.aten.div.Tensor,
torch.ops.aten.softmax.int,
]
self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)

@@ -82,7 +82,9 @@ default_quantizable_ops = propagation_quantizable_ops | {
# A superset of default_quantizable_ops includes operators support the int8 data type
# but not enabled by default recipe of X86InductorQuantizer.
quantizable_ops = default_quantizable_ops
quantizable_ops = default_quantizable_ops | {
torch.ops.aten.matmul.default,
}
QUANT_ANNOTATION_KEY = "quantization_annotation"
@@ -110,6 +112,12 @@ def _map_module_function_to_aten_operator_type():
],
torch.ops.aten.flatten.using_ints,
),
(
[
torch.matmul,
],
torch.ops.aten.matmul.default,
),
)
for map_item in map_list:
module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[call-overload]
@@ -310,6 +318,14 @@ class X86InductorQuantizer(Quantizer):
self.global_config = quantization_config
return self
def get_global_quantization_config(self):
if not isinstance(self.global_config, QuantizationConfig):
warnings.warn(
"The global_config for X86InductorQuantizer is currently invalid. \
Please ensure that you use set_global to establish the global quantization configuration."
)
return self.global_config
def set_function_type_qconfig(
self,
function_type: Callable,
@@ -499,6 +515,7 @@ class X86InductorQuantizer(Quantizer):
# Step1: Recipe of fusion patterns like conv/linear.
self._annotate_conv2d_fusion_pattern(model)
self._annotate_linear_fusion_pattern(model)
self._annotate_matmul(model)
# Step2: Recipe to propagate annotation for patterns beside conv/linear.
# Go through all the nodes from start to end.
@@ -752,6 +769,24 @@ class X86InductorQuantizer(Quantizer):
self._annotate_linear_unary(model, config)
self._annotate_linear(model, config)
def _annotate_matmul(self, model: torch.fx.GraphModule):
if config := self._get_aten_operator_qconfig(torch.ops.aten.matmul.default):
for node in model.graph.nodes:
if node.target == torch.ops.aten.matmul.default and not _is_annotated(
[node]
):
input_qspec_map = {}
matmul_node = node
for input_node in matmul_node.args:
input_qspec_map[input_node] = get_input_act_qspec(config)
matmul_node.meta[
QUANT_ANNOTATION_KEY
] = _X86InductorQuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
_is_output_of_quantized_pattern=True,
)
def _annotate_conv2d_binary_unary(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None: