mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Quant][PT2E] Enable linear and linear-unary post-op gelu quant recipe for x86 inductor quantizer (#114853)
**Summary**
Add GELU to the linear-unary post-op quantization recipe of the x86 inductor quantizer.

**Test plan**
python -m pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_linear_unary_gelu
python test/test_quantization.py -k test_linear_unary_with_quantizer_api

Co-authored-by: leslie-fang-intel <leslie.fang@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114853
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jerryzh168
This commit is contained in:
committed by: PyTorch MergeBot
parent: a04e7fca8e
commit: 25e00545bb
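For readers landing on this commit, here is a minimal sketch of how the new recipe is exercised end to end. This is not part of the diff: the model, names, and shapes are illustrative, and it assumes the PT2E entry points of this PyTorch generation (`capture_pre_autograd_graph`, `prepare_pt2e`, `convert_pt2e`).

```python
import torch
import torch.nn as nn
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

class LinearGelu(nn.Module):
    """Illustrative model hitting the new linear + gelu recipe."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.gelu = nn.GELU(approximate="tanh")

    def forward(self, x):
        return self.gelu(self.linear(x))

m = LinearGelu().eval()
example_inputs = (torch.randn(2, 4),)

# Export to an ATen-level graph, annotate it with the x86 quantizer,
# calibrate, then convert to the reference quantized graph.
exported = capture_pre_autograd_graph(m, example_inputs)
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)

# torch.compile hands the q/dq + linear + gelu pattern to Inductor for fusion.
optimized = torch.compile(converted)
optimized(*example_inputs)
```

After `convert_pt2e`, the graph contains the quantize/dequantize + linear + gelu pattern that the new test below asserts on.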
@@ -251,10 +251,13 @@ class TestHelperModules:
             return self.linear(x)
 
     class LinearUnaryModule(torch.nn.Module):
-        def __init__(self, use_bias, postop, inplace_postop) -> None:
+        def __init__(self, use_bias, postop, inplace_postop=False, post_op_algo='none') -> None:
             super().__init__()
             self.linear = nn.Linear(4, 4, bias=use_bias)
-            self.postop = postop(inplace=inplace_postop)
+            if postop == nn.GELU:
+                self.postop = postop(approximate=post_op_algo)
+            else:
+                self.postop = postop(inplace=inplace_postop)
 
         def forward(self, x):
             return self.postop(self.linear(x))
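The GELU special case above exists because `nn.GELU` takes no `inplace` flag; its constructor instead selects the algorithm via `approximate`. A quick illustration (plain PyTorch, nothing PR-specific):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 4)
exact = nn.GELU(approximate="none")(x)   # erf-based GELU (the default)
approx = nn.GELU(approximate="tanh")(x)  # tanh approximation

# The two algorithms track each other closely but are not bit-identical,
# which is why the test below sweeps post_op_algo over both values.
print(torch.allclose(exact, approx, atol=1e-2))
```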
@@ -1010,6 +1013,44 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 node_list,
             )
 
+    @skipIfNoX86
+    def test_linear_unary_gelu(self):
+        """
+        Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer.
+        """
+        use_bias_list = [True, False]
+        postop = nn.GELU
+        post_op_algorithm = ['none', 'tanh']
+        cases = itertools.product(use_bias_list, post_op_algorithm)
+        with override_quantized_engine("x86"), torch.no_grad():
+            for use_bias, post_op_algo in cases:
+                m = TestHelperModules.LinearUnaryModule(use_bias=use_bias, postop=postop, post_op_algo=post_op_algo).eval()
+                example_inputs = (torch.randn(2, 4),)
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                node_occurrence = {
+                    # one quantize/dequantize pair for the activation input of the linear
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+                    # quantize_per_channel for the weight is const-propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.linear.default,
+                    torch.ops.aten.gelu.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
 
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
     def test_qat_conv2d(self):
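The occurrence counts and node order above encode where the recipe places quantize/dequantize ops. Roughly, the annotated graph the test expects looks like this (a sketch, not verbatim tooling output):

```python
# x ── quantize_per_tensor ── dequantize_per_tensor ──┐
#                                                     ├── aten.linear ── aten.gelu ── out
# w (int8 const) ── dequantize_per_channel ───────────┘
#
# quantize_per_channel does not appear (count 0) because the weight is
# quantized ahead of time and const-propagated; only its dequantize remains.
# Inductor can later fuse dequant + linear + gelu into a single kernel.
```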
@@ -378,11 +378,13 @@ def _maybe_insert_input_observers_for_node(
         numeric_debug_handle = node.meta["numeric_debug_handle"]
         node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
 
-    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
-    # that persist in exported graph. This is just a work around for these.
+    # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
+    # gelu has an approximate kwarg that persist in the exported graph.
+    # This is just a workaround for these.
     assert (
         node.target == torch.ops.aten.clone.default or
         node.target == torch.ops.aten.zeros_like.default or
         node.target == torch.ops.aten.gelu.default or
         len(node.kwargs) == 0
     ), " expecting kwargs for aten op IR to be empty"
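The relaxed assert is needed because a handful of ATen ops carry kwargs through export. A small probe of that behavior (illustrative; it uses `torch.export.export`, while the PT2E flow of this era captured via `capture_pre_autograd_graph` — the expectation that `approximate` persists as a kwarg is an assumption based on the comment above):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # Explicit non-default kwarg: expected to survive into the graph.
        return torch.nn.functional.gelu(x, approximate="tanh")

ep = torch.export.export(M(), (torch.randn(2, 4),))
for node in ep.graph.nodes:
    if node.op == "call_function":
        # Expect aten.gelu.default to show {'approximate': 'tanh'} here,
        # while most ATen nodes have empty kwargs.
        print(node.target, node.kwargs)
```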
@@ -985,6 +985,7 @@ class X86InductorQuantizer(Quantizer):
             torch.nn.ReLU,
             torch.nn.LeakyReLU,
             torch.nn.Tanh,
+            torch.nn.GELU,
         ]
         fused_partitions: List[tuple] = []
         for postop in postop_list:
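For context on how `postop_list` is consumed: the surrounding annotation code pairs `nn.Linear` with each listed post-op class via sequential partition matching, so adding `nn.GELU` makes linear → gelu chains eligible for fusion annotation. A paraphrased sketch of that loop (not the verbatim quantizer code, assuming `find_sequential_partitions` from `torch.ao.quantization.pt2e.graph_utils`):

```python
import torch
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

def collect_linear_unary_partitions(gm: torch.fx.GraphModule, postop_list):
    """Gather linear -> post-op chains for every class in postop_list."""
    fused_partitions = []
    for postop in postop_list:
        # Each match pairs a Linear partition with the post-op partition
        # found immediately downstream of it in the captured graph.
        fused_partitions.extend(
            find_sequential_partitions(gm, [torch.nn.Linear, postop])
        )
    return fused_partitions
```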