[Quant][PT2E][X86] annotate and convert for linear_dynamic_fp16 (#141480)

Annotate linear nodes for `linear_dynamic_fp16` with `X86InductorQuantizer`.
After `convert_pt2e`, the pattern will be
```
  x
  |
linear <- to_fp32 <- to_fp16 <- w
```
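
For reference, a minimal end-to-end sketch of how this configuration is meant to be driven from user code. The `ToyLinear` module and the `export_for_training` capture step are illustrative only and may differ by PyTorch version; the quantizer APIs are the ones added or exercised in this PR.
```
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class ToyLinear(torch.nn.Module):  # illustrative stand-in for a real model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 8)

    def forward(self, x):
        return self.linear(x)


m = ToyLinear().eval()
example_inputs = (torch.randn(2, 4),)
exported = torch.export.export_for_training(m, example_inputs).module()

quantizer = X86InductorQuantizer()
quantizer.set_module_type_qconfig(
    torch.nn.Linear, xiq.get_x86_inductor_linear_dynamic_fp16_config()
)

prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # no real calibration needed; the weight spec uses a PlaceholderObserver
converted = convert_pt2e(prepared)
# The linear weight is now fed through to_fp16 -> to_fp32 casts, as in the diagram above.
```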

**Test plan**
```
pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_linear_dynamic_fp16
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141480
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Author: Xia, Weiwen
Date: 2024-11-28 17:01:02 -08:00
Committed by: PyTorch MergeBot
Parent: b7a45dbae3
Commit: 9827d677b4

5 changed files with 87 additions and 1 deletion

```
@@ -1804,6 +1804,41 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             }
             self._check_annotation_stat(fq_m, expected_annotation_stat)
 
+    @skipIfNoX86
+    def test_linear_dynamic_fp16(self):
+        """
+        Test pattern of linear_dynamic_fp16.
+        """
+        with override_quantized_engine("x86"), torch.no_grad():
+            for use_bias in [True, False]:
+                m = TestHelperModules.SingleLinearModule(use_bias).eval()
+                example_inputs = (torch.randn(2, 4),)
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                quantizer.set_module_type_qconfig(
+                    torch.nn.Linear, xiq.get_x86_inductor_linear_dynamic_fp16_config()
+                )
+                node_occurrence = {
+                    # 2 convert_element_type nodes are inserted for weight
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.convert_element_type.no_fuse: 2,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.convert_element_type.no_fuse,
+                    torch.ops.aten.linear.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
+
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
     def test_qat_conv2d(self):
```

```
@@ -1185,3 +1185,22 @@ def fake_quant_per_channel_meta(
     quant_max: int,
 ) -> torch.Tensor:
     return torch.empty_like(input)
+
+
+quantized_decomposed_lib.define(
+    "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "convert_element_type.no_fuse",
+    "CompositeExplicitAutograd",
+)
+def convert_element_type(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return torch.ops.prims.convert_element_type.default(input, dtype)
+
+
+@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "Meta")
+def convert_element_type_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return torch.empty_like(input, dtype=dtype)
```
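
A quick, hedged illustration (not part of the diff) of what the new overload does, assuming the registration above is pulled in by importing `torch.ao.quantization.fx._decomposed`:
```
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers quantized_decomposed ops

w = torch.randn(8, 4)
# Plain dtype casts; the `.no_fuse` overload name presumably keeps the casts
# from being folded into other quantization patterns during lowering.
w_fp16 = torch.ops.quantized_decomposed.convert_element_type.no_fuse(w, torch.float16)
w_back = torch.ops.quantized_decomposed.convert_element_type.no_fuse(w_fp16, torch.float32)
assert w_fp16.dtype == torch.float16
assert w_back.dtype == torch.float32  # values have been rounded through fp16 precision
```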

```
@@ -342,7 +342,18 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
             ]
         graph.erase_node(node)
     elif dtype == torch.float16:
-        raise NotImplementedError("decomposed to float16 op not implemented yet")
+        # Insert to_fp16 -> to_fp32 node
+        dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            convert_fp16_node = graph.create_node(
+                "call_function", dtype_convert_op, (input_node, torch.float16), {}
+            )
+            convert_fp32_node = graph.create_node(
+                "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {}
+            )
+            node.replace_all_uses_with(convert_fp32_node)
+            graph.erase_node(node)
 
     # should not reach since we have checks in the beginning to make sure the
     # activation_post_process is supported
```
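
The same graph-surgery idiom, shown as a self-contained torch.fx sketch. Everything here is made up for illustration: `fake_observer` and `WeightOnly` are stand-ins for the observer node and model, and `prims.convert_element_type` is used in place of the new `quantized_decomposed` overload so the snippet runs on existing builds.
```
import torch
import torch.fx as fx


def fake_observer(x):
    # Stand-in for an observer / activation_post_process call.
    return x


# Keep fake_observer as a call_function node instead of tracing through it.
fx.wrap("fake_observer")


class WeightOnly(torch.nn.Module):
    def forward(self, w):
        return fake_observer(w)


gm = fx.symbolic_trace(WeightOnly())
graph = gm.graph

for node in list(graph.nodes):
    if node.op == "call_function" and node.target is fake_observer:
        # Same idiom as the convert pass: cast to fp16 and back to fp32,
        # then splice the cast pair in where the observer node used to be.
        with graph.inserting_before(node):
            to_fp16 = graph.call_function(
                torch.ops.prims.convert_element_type.default,
                (node.args[0], torch.float16),
            )
            to_fp32 = graph.call_function(
                torch.ops.prims.convert_element_type.default,
                (to_fp16, torch.float),
            )
        node.replace_all_uses_with(to_fp32)
        graph.erase_node(node)

graph.lint()
gm.recompile()
print(gm.code)  # forward now contains the to_fp16 -> to_fp32 cast pair
```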

```
@@ -62,6 +62,7 @@ if TYPE_CHECKING:
 __all__ = [
     "X86InductorQuantizer",
     "get_default_x86_inductor_quantization_config",
+    "get_x86_inductor_linear_dynamic_fp16_config",
 ]
```

```
@@ -341,6 +342,25 @@ def get_default_x86_inductor_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_x86_inductor_linear_dynamic_fp16_config():
+    """
+    For linear_dynamic_fp16. The name may be confusing.
+    The op's behavior is fp32_input * (fp16_weight -> to_fp32) -> fp32_output.
+    """
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.float16,
+        observer_or_fake_quant_ctr=PlaceholderObserver,
+    )
+    quantization_config = QuantizationConfig(
+        None,  # input_quantization_spec
+        None,  # output_quantization_spec
+        weight_quantization_spec,
+        None,  # bias_quantization_spec
+    )
+    return quantization_config
+
+
 def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None:
     """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`)."""
     if not isinstance(nodes, list):
```

```
@@ -152,6 +152,7 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
     if quantization_spec.qscheme not in [
         torch.per_tensor_symmetric,
         torch.per_channel_symmetric,
+        None,
     ]:
         raise ValueError(
             f"Unsupported quantization_spec {quantization_spec} for weight"
```