[Quant][PT2E][X86] annotate and convert for linear_dynamic_fp16 (#141480)

Annotate the linear node for `linear_dynamic_fp16` with `X86InductorQuantizer`. After `convert_pt2e`, the pattern will be

```
  x
  |
linear <- to_fp32 <- to_fp16 <- w
```

**Test plan**

```
pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_linear_dynamic_fp16
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141480
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: b7a45dbae3
Commit: 9827d677b4
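For context, here is a minimal end-to-end sketch (not part of the PR) of how the new config is meant to be used in the PT2E flow, mirroring the test added below. It assumes a build that contains this PR; the toy module, shapes, and the use of `torch.export.export_for_training` to obtain the graph module are illustrative assumptions.

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e


class ToyLinear(torch.nn.Module):  # made-up model for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(2, 4),)
# Export to an ATen graph module for PT2E quantization.
m = torch.export.export_for_training(ToyLinear().eval(), example_inputs).module()

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
# Route all Linear modules to the linear_dynamic_fp16 config added in this PR.
quantizer.set_module_type_qconfig(
    torch.nn.Linear, xiq.get_x86_inductor_linear_dynamic_fp16_config()
)

m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration pass (a no-op for this weight-only fp16 config)
m = convert_pt2e(m)
# The weight now reaches aten.linear through to_fp16 -> to_fp32
# (two convert_element_type.no_fuse nodes), matching the pattern above.
print(m.graph)
```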
@@ -1804,6 +1804,41 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 }
                 self._check_annotation_stat(fq_m, expected_annotation_stat)
 
+    @skipIfNoX86
+    def test_linear_dynamic_fp16(self):
+        """
+        Test pattern of linear_dynamic_fp16.
+        """
+        with override_quantized_engine("x86"), torch.no_grad():
+            for use_bias in [True, False]:
+                m = TestHelperModules.SingleLinearModule(use_bias).eval()
+                example_inputs = (torch.randn(2, 4),)
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                quantizer.set_module_type_qconfig(
+                    torch.nn.Linear, xiq.get_x86_inductor_linear_dynamic_fp16_config()
+                )
+                node_occurrence = {
+                    # 2 convert_element_type nodes are inserted for weight
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.convert_element_type.no_fuse: 2,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.convert_element_type.no_fuse,
+                    torch.ops.aten.linear.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
+
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
     def test_qat_conv2d(self):
@@ -1185,3 +1185,22 @@ def fake_quant_per_channel_meta(
     quant_max: int,
 ) -> torch.Tensor:
     return torch.empty_like(input)
+
+
+quantized_decomposed_lib.define(
+    "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "convert_element_type.no_fuse",
+    "CompositeExplicitAutograd",
+)
+def convert_element_type(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return torch.ops.prims.convert_element_type.default(input, dtype)
+
+
+@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "Meta")
+def convert_element_type_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return torch.empty_like(input, dtype=dtype)
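As a quick illustration (not part of the diff) of what the newly registered op computes: it is a plain dtype cast, so chaining it fp32 -> fp16 -> fp32 reproduces the weight round-trip described in the commit message. The import used to trigger the registration is an assumption about where `quantized_decomposed_lib` lives, and the `no_fuse` overload only exists on a build containing this PR.

```python
import torch
# Importing the decomposed-ops module registers quantized_decomposed.* ops;
# the module path is assumed, and no_fuse requires a build with this PR.
import torch.ao.quantization.fx._decomposed  # noqa: F401

op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
w = torch.randn(8, 4)              # an fp32 weight
w_fp16 = op(w, torch.float16)      # to_fp16
w_back = op(w_fp16, torch.float)   # to_fp32
assert w_back.dtype == torch.float32
# The round trip keeps fp32 dtype but only fp16 precision of the values.
print((w - w_back).abs().max())
```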
@@ -342,7 +342,18 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
             ]
         graph.erase_node(node)
     elif dtype == torch.float16:
-        raise NotImplementedError("decomposed to float16 op not implemented yet")
+        # Insert to_fp16 -> to_fp32 node
+        dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            convert_fp16_node = graph.create_node(
+                "call_function", dtype_convert_op, (input_node, torch.float16), {}
+            )
+            convert_fp32_node = graph.create_node(
+                "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {}
+            )
+            node.replace_all_uses_with(convert_fp32_node)
+            graph.erase_node(node)
 
     # should not reach since we have checks in the beginning to make sure the
     # activation_post_process is supported
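A toy sketch (not PyTorch's convert code) of the graph surgery in this hunk: an observer-like placeholder call is replaced by the to_fp16 -> to_fp32 pair, using the same `graph.inserting_before` / `create_node` / `replace_all_uses_with` steps. The `fake_observer` stand-in and the module are made up for illustration, and the op-registration import path is assumed as above.

```python
import torch
import torch.fx as fx
import torch.ao.quantization.fx._decomposed  # noqa: F401  (assumed registration path)


@fx.wrap  # keep the call as a graph node instead of tracing through it
def fake_observer(x):
    return x


class ToyModule(torch.nn.Module):  # made-up module for illustration
    def forward(self, x, w):
        # Observer-like calls standing in for activation_post_process nodes.
        return torch.nn.functional.linear(fake_observer(x), fake_observer(w))


gm = fx.symbolic_trace(ToyModule())
graph = gm.graph
dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse

for node in list(graph.nodes):
    if node.op == "call_function" and node.target is fake_observer:
        with graph.inserting_before(node):
            input_node = node.args[0]
            convert_fp16_node = graph.create_node(
                "call_function", dtype_convert_op, (input_node, torch.float16), {}
            )
            convert_fp32_node = graph.create_node(
                "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {}
            )
        node.replace_all_uses_with(convert_fp32_node)
        graph.erase_node(node)

graph.lint()
gm.recompile()
print(gm.graph)  # linear now consumes ... <- to_fp32 <- to_fp16 on both inputs
```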
@@ -62,6 +62,7 @@ if TYPE_CHECKING:
 __all__ = [
     "X86InductorQuantizer",
     "get_default_x86_inductor_quantization_config",
+    "get_x86_inductor_linear_dynamic_fp16_config",
 ]
 
 
@@ -341,6 +342,25 @@ def get_default_x86_inductor_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_x86_inductor_linear_dynamic_fp16_config():
+    """
+    For linear_dynamic_fp16. The name may be confusing.
+    The op's behavior is fp32_input * (fp16_weight -> to_fp32) -> fp32_output.
+    """
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.float16,
+        observer_or_fake_quant_ctr=PlaceholderObserver,
+    )
+    quantization_config = QuantizationConfig(
+        None,  # input_quantization_spec
+        None,  # output_quantization_spec
+        weight_quantization_spec,
+        None,  # bias_quantization_spec
+    )
+    return quantization_config
+
+
 def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None:
     """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`)."""
     if not isinstance(nodes, list):
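A small numeric sketch (not in the diff) of the behavior the docstring above describes: the weight is kept in fp16, converted back to fp32, and the matmul itself runs in fp32. Shapes and values are arbitrary.

```python
import torch

x = torch.randn(2, 4)          # fp32 activation
w = torch.randn(8, 4)          # fp32 weight as authored
b = torch.randn(8)

w_fp16 = w.to(torch.float16)   # what linear_dynamic_fp16 stores
out = torch.nn.functional.linear(x, w_fp16.to(torch.float32), b)  # fp32 compute
print(out.dtype)  # torch.float32
```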
@@ -152,6 +152,7 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
     if quantization_spec.qscheme not in [
         torch.per_tensor_symmetric,
         torch.per_channel_symmetric,
+        None,
     ]:
         raise ValueError(
             f"Unsupported quantization_spec {quantization_spec} for weight"
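The `None` entry added here ties back to the new config above: `get_x86_inductor_linear_dynamic_fp16_config()` builds a weight `QuantizationSpec` without a qscheme, so `get_weight_qspec` has to accept `qscheme=None` instead of raising. A tiny sketch, assuming a build that includes this PR:

```python
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

cfg = xiq.get_x86_inductor_linear_dynamic_fp16_config()
# The fp16 weight spec carries no qscheme, which is why None is now allowed.
print(cfg.weight.dtype)    # torch.float16
print(cfg.weight.qscheme)  # None
```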