[Quant][PT2E] Enable linear and linear-unary post-op gelu quant recipe for x86 inductor quantizer (#114853)

**Summary**
Add GELU to the linear-unary post-op quantization recipe of the x86 inductor quantizer.
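
For context, a minimal sketch of how this recipe could be exercised end to end. It assumes the PT2E entry points available around this change (`capture_pre_autograd_graph`, `prepare_pt2e`, `convert_pt2e`); the `LinearGelu` module below is a hypothetical stand-in for `TestHelperModules.LinearUnaryModule`, not part of this PR.

```python
import torch
import torch.nn as nn
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class LinearGelu(nn.Module):
    """Hypothetical toy model: the linear + gelu pattern this recipe annotates."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.gelu = nn.GELU(approximate="tanh")

    def forward(self, x):
        return self.gelu(self.linear(x))


m = LinearGelu().eval()
example_inputs = (torch.randn(2, 4),)

# Export the eager model, annotate it with the default x86 inductor recipe,
# run one calibration pass, then convert to the quantized graph.
exported = capture_pre_autograd_graph(m, example_inputs)
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)
```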

**Test plan**
python -m pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_linear_unary_gelu
python test/test_quantization.py -k test_linear_unary_with_quantizer_api
Co-authored-by: leslie-fang-intel <leslie.fang@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114853
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jerryzh168
Author: Le-Zheng
Date: 2024-03-13 04:19:00 -07:00
Committed by: PyTorch MergeBot
Parent: a04e7fca8e
Commit: 25e00545bb
3 changed files with 48 additions and 4 deletions


@@ -251,10 +251,13 @@ class TestHelperModules:
             return self.linear(x)
 
     class LinearUnaryModule(torch.nn.Module):
-        def __init__(self, use_bias, postop, inplace_postop) -> None:
+        def __init__(self, use_bias, postop, inplace_postop=False, post_op_algo='none') -> None:
             super().__init__()
             self.linear = nn.Linear(4, 4, bias=use_bias)
-            self.postop = postop(inplace=inplace_postop)
+            if postop == nn.GELU:
+                self.postop = postop(approximate=post_op_algo)
+            else:
+                self.postop = postop(inplace=inplace_postop)
 
         def forward(self, x):
             return self.postop(self.linear(x))
@@ -1010,6 +1013,44 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                     node_list,
                 )
 
+    @skipIfNoX86
+    def test_linear_unary_gelu(self):
+        """
+        Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer.
+        """
+        use_bias_list = [True, False]
+        postop = nn.GELU
+        post_op_algorithm = ['none', 'tanh']
+        cases = itertools.product(use_bias_list, post_op_algorithm)
+        with override_quantized_engine("x86"), torch.no_grad():
+            for use_bias, post_op_algo in cases:
+                m = TestHelperModules.LinearUnaryModule(use_bias=use_bias, postop=postop, post_op_algo=post_op_algo).eval()
+                example_inputs = (torch.randn(2, 4),)
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                node_occurrence = {
+                    # one for input and weight of the linear, one for output for the gelu
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.linear.default,
+                    torch.ops.aten.gelu.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
+
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
     def test_qat_conv2d(self):


@@ -378,11 +378,13 @@ def _maybe_insert_input_observers_for_node(
         numeric_debug_handle = node.meta["numeric_debug_handle"]
         node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
 
-    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
-    # that persist in exported graph. This is just a work around for these.
+    # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
+    # gelu has an approximate kwarg that persist in the exported graph.
+    # This is just a workaround for these.
     assert (
         node.target == torch.ops.aten.clone.default or
         node.target == torch.ops.aten.zeros_like.default or
+        node.target == torch.ops.aten.gelu.default or
         len(node.kwargs) == 0
     ), " expecting kwargs for aten op IR to be empty"


@@ -985,6 +985,7 @@ class X86InductorQuantizer(Quantizer):
             torch.nn.ReLU,
             torch.nn.LeakyReLU,
             torch.nn.Tanh,
+            torch.nn.GELU,
         ]
         fused_partitions: List[tuple] = []
         for postop in postop_list:
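
With `torch.nn.GELU` in `postop_list`, the quantizer's sequential pattern matching can also pick up Linear → GELU chains. A rough sketch of that matching step, assuming the `find_sequential_partitions` helper this loop feeds (its import path below is a best guess):

```python
import torch
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions


def find_linear_gelu_partitions(gm: torch.fx.GraphModule):
    # Hypothetical helper: returns one (linear_partition, gelu_partition)
    # tuple per Linear -> GELU sequence found in the exported graph.
    return find_sequential_partitions(gm, [torch.nn.Linear, torch.nn.GELU])
```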