[Quant][PT2E] Enable linear and linear-unary post-op gelu quant recipe for x86 inductor quantizer (#114853)
**Summary**
Add GELU to the linear-unary post-op quantization recipe in the x86 inductor quantizer.

**Test plan**
python -m pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_linear_unary_gelu
python test/test_quantization.py -k test_linear_unary_with_quantizer_api

Co-authored-by: leslie-fang-intel <leslie.fang@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114853
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: a04e7fca8e
Commit: 25e00545bb
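For context, here is a minimal end-to-end sketch of the recipe this commit enables: quantizing a linear + GELU module with X86InductorQuantizer through the PT2E flow. The `LinearGELU` module is illustrative, and the capture API shown is an assumption based on the PT2E flow of this era of PyTorch; only the quantizer calls mirror the test added in this PR.

```python
import torch
import torch.nn as nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

# Illustrative module: linear followed by gelu, the pattern this recipe fuses.
class LinearGELU(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.gelu = nn.GELU(approximate="tanh")  # 'none' and 'tanh' are both covered

    def forward(self, x):
        return self.gelu(self.linear(x))

m = LinearGELU().eval()
example_inputs = (torch.randn(2, 4),)

# Capture to an ATen graph, annotate with the x86 inductor quantizer,
# calibrate, then convert to the quantized reference graph.
exported = capture_pre_autograd_graph(m, example_inputs)
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)
```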
```diff
@@ -251,10 +251,13 @@ class TestHelperModules:
             return self.linear(x)
 
     class LinearUnaryModule(torch.nn.Module):
-        def __init__(self, use_bias, postop, inplace_postop) -> None:
+        def __init__(self, use_bias, postop, inplace_postop=False, post_op_algo='none') -> None:
             super().__init__()
             self.linear = nn.Linear(4, 4, bias=use_bias)
-            self.postop = postop(inplace=inplace_postop)
+            if postop == nn.GELU:
+                self.postop = postop(approximate=post_op_algo)
+            else:
+                self.postop = postop(inplace=inplace_postop)
 
         def forward(self, x):
             return self.postop(self.linear(x))
```
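The branch above exists because nn.GELU takes no inplace flag; instead it takes approximate ('none' or 'tanh'), selecting between the exact erf formulation and the tanh approximation. A quick sketch of the two variants:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 4)
exact = nn.GELU(approximate="none")  # default: erf-based gelu
tanh = nn.GELU(approximate="tanh")   # tanh polynomial approximation
# nn.GELU has no inplace parameter, unlike nn.ReLU or nn.LeakyReLU,
# hence the special case in LinearUnaryModule above.
print(torch.allclose(exact(x), tanh(x), atol=1e-3))  # close but not identical
```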
```diff
@@ -1010,6 +1013,44 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
                 node_list,
             )
 
+    @skipIfNoX86
+    def test_linear_unary_gelu(self):
+        """
+        Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer.
+        """
+        use_bias_list = [True, False]
+        postop = nn.GELU
+        post_op_algorithm = ['none', 'tanh']
+        cases = itertools.product(use_bias_list, post_op_algorithm)
+        with override_quantized_engine("x86"), torch.no_grad():
+            for use_bias, post_op_algo in cases:
+                m = TestHelperModules.LinearUnaryModule(use_bias=use_bias, postop=postop, post_op_algo=post_op_algo).eval()
+                example_inputs = (torch.randn(2, 4),)
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                node_occurrence = {
+                    # one for the input and weight of the linear, one for the output of the gelu
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.linear.default,
+                    torch.ops.aten.gelu.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
+
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
     def test_qat_conv2d(self):
```
```diff
@@ -378,11 +378,13 @@ def _maybe_insert_input_observers_for_node(
         numeric_debug_handle = node.meta["numeric_debug_handle"]
         node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
 
-    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
-    # that persist in exported graph. This is just a work around for these.
+    # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
+    # gelu has an approximate kwarg that persist in the exported graph.
+    # This is just a workaround for these.
     assert (
         node.target == torch.ops.aten.clone.default or
         node.target == torch.ops.aten.zeros_like.default or
+        node.target == torch.ops.aten.gelu.default or
         len(node.kwargs) == 0
    ), " expecting kwargs for aten op IR to be empty"
```
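As a hedged illustration of why gelu needs this exemption: when a gelu call is captured, the approximate argument survives as a kwarg on the aten.gelu.default node, so the empty-kwargs assertion would otherwise fire. A sketch, assuming the capture API of this era:

```python
import torch
from torch._export import capture_pre_autograd_graph

class G(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate="tanh")

gm = capture_pre_autograd_graph(G(), (torch.randn(2, 4),))
for node in gm.graph.nodes:
    if node.target == torch.ops.aten.gelu.default:
        print(node.kwargs)  # expected: {'approximate': 'tanh'}
```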
```diff
@@ -985,6 +985,7 @@ class X86InductorQuantizer(Quantizer):
             torch.nn.ReLU,
             torch.nn.LeakyReLU,
             torch.nn.Tanh,
+            torch.nn.GELU,
         ]
         fused_partitions: List[tuple] = []
         for postop in postop_list:
```
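For context on how postop_list is consumed, here is a sketch based on the surrounding code rather than an exact copy (the helper name is hypothetical, and the find_sequential_partitions signature is an assumption): each (Linear, post-op) pair is matched as a sequential partition in the captured graph, and with this change Linear → GELU sequences are collected too.

```python
import torch
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

def sketch_collect_linear_unary(gm, postop_list):
    # For each unary post-op class, find Linear -> post-op sequences in the
    # graph; with this commit, torch.nn.GELU is part of postop_list.
    fused_partitions = []
    for postop in postop_list:
        fused_partitions.extend(
            find_sequential_partitions(gm, [torch.nn.Linear, postop])
        )
    return fused_partitions
```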