graph mode: add hardswish inplace handling (#40284)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40284

Adds graph mode handling for inplace hardswish, and test coverage for functional hardswish.

Test Plan:
```
python test/test_quantization.py TestQuantizeScriptPTSQOps.test_hardswish
```

Imported from OSS

Differential Revision: D22140628

fbshipit-source-id: 55a514f7dc1130d510f69ee4e611d7cb5e08d02e
Committed by Facebook GitHub Bot
Parent: c6dbfcaf9e
Commit: ab8a99bd36
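For orientation, here is a minimal sketch (not taken from this PR) of the graph-mode flow the change targets: a module that calls hardswish in place is scripted, calibrated, and quantized, after which its graph should contain `quantized::hardswish` rather than `aten::hardswish_`. The qconfig, calibration helper, and shapes are illustrative assumptions; `quantize_jit` is the same entry point used by the test utilities in the last hunk below.

```python
import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig, quantize_jit

class M(torch.nn.Module):
    def forward(self, x):
        # in-place variant: shows up as aten::hardswish_ in the scripted graph
        return F.hardswish(x, inplace=True)

def calibrate(model, data):
    # illustrative calibration loop: run the observed model on sample inputs
    for x, _ in data:
        model(x)

data = [(torch.rand(1, 2, 5, 5), torch.tensor(0)) for _ in range(2)]
qconfig_dict = {'': get_default_qconfig('fbgemm')}

scripted = torch.jit.script(M()).eval()
quantized = quantize_jit(scripted, qconfig_dict, calibrate, [data])
print(quantized.graph)  # expected: quantized::hardswish, no aten::hardswish_
```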
@@ -2177,11 +2177,23 @@ class TestQuantizeJitOps(QuantizationTestCase):
                 .run(m.graph)
 
     def test_hardswish(self):
-        data = [(torch.rand((1, 2, 5, 5), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
-        hardswish = torch.nn.Hardswish()
-        for tracing in [True, False]:
-            m = self.checkGraphModeOp(hardswish, data, "quantized::hardswish", tracing)
-            FileCheck().check_not("aten::hardswish") \
-                .run(m.graph)
+        class FunctionalHardswish(torch.nn.Module):
+            def __init__(self, inplace):
+                super(FunctionalHardswish, self).__init__()
+                self.inplace = inplace
+
+            def forward(self, input):
+                return torch.nn.functional.hardswish(input, inplace=self.inplace)
+
+        modules = [torch.nn.Hardswish(), FunctionalHardswish(True),
+                   FunctionalHardswish(False)]
+
+        for test_case in itertools.product([True, False], modules):
+            tracing, m = test_case
+            m = self.checkGraphModeOp(
+                m, self.img_data, "quantized::hardswish", tracing)
+            FileCheck().check_not("aten::hardswish") \
+                .check_not("aten::hardswish_") \
+                .run(m.graph)
 
     def test_elu(self):
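The new FunctionalHardswish(True) case exists because the in-place functional call is recorded as a distinct op, aten::hardswish_, in TorchScript. A standalone sketch (not part of the test suite; the module name and input shape are made up) of what the pre-quantization graph looks like:

```python
import torch
import torch.nn.functional as F

class InplaceHardswish(torch.nn.Module):
    def forward(self, x):
        return F.hardswish(x, inplace=True)

# Tracing records the in-place kernel, so the graph contains an
# aten::hardswish_ node -- the op the FileCheck assertions above expect
# to disappear once graph mode quantization rewrites it.
traced = torch.jit.trace(InplaceHardswish(), torch.rand(1, 2, 5, 5))
print(traced.graph)
```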
@@ -34,6 +34,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
     "addmm",
     "matmul",
     "hardswish",
+    "hardswish_",
     "elu",
     "elu_",
     "batch_norm",
@@ -848,6 +848,9 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
   auto hardswish = getObservedQParamOpFusionInfo(
       "aten::hardswish", "quantized::hardswish", {}, {});
 
+  auto hardswish_ = getObservedQParamOpFusionInfo(
+      "aten::hardswish_", "quantized::hardswish", {}, {});
+
   auto layer_norm = getObservedQParamOpFusionInfo(
       "aten::layer_norm",
       "quantized::layer_norm",
@@ -968,6 +971,7 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
       {"quantized::mul", mul, quantized_mul},
       {"quantized::mul", inplace_mul, quantized_mul},
       hardswish,
+      hardswish_,
       layer_norm,
       group_norm,
       instance_norm,
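Both fusion entries target the same out-of-place kernel: whether the observed graph contains aten::hardswish or aten::hardswish_, it is rewritten to quantized::hardswish, which consumes the output scale and zero point collected by the observer. A rough sketch of calling that kernel directly; the three-argument signature is assumed from the fusion pattern, and the quantization parameters are made-up values:

```python
import torch

x = torch.rand(1, 2, 5, 5)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)

# The output scale / zero point would normally come from the observer that
# watched this op's output; here they are arbitrary illustrative numbers.
qy = torch.ops.quantized.hardswish(qx, 0.02, 8)
print(qy.dequantize().shape)
```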
@ -7,6 +7,7 @@ r"""Importing this file includes common utility methods and base clases for
|
||||
checking quantization api and properties of resulting modules.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import io
|
||||
import functools
|
||||
import torch
|
||||
@@ -282,8 +283,12 @@ class QuantizationTestCase(TestCase):
                 # make sure it runs
                 outputs[d] = models[d](inputs)
             else:
+                # module under test can contain in-place ops, and we depend on
+                # input data staying constant for comparisons
+                data_copy = copy.deepcopy(data)
                 models[d] = quantize_jit(
-                    model, qconfig_dict, test_only_eval_fn, [data], inplace=False, debug=d)
+                    model, qconfig_dict, test_only_eval_fn, [data_copy], inplace=False,
+                    debug=d)
                 # make sure it runs
                 outputs[d] = models[d](*inputs)
 
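The deepcopy in the else branch guards against exactly the behavior this PR adds coverage for: an in-place op overwrites its input, so reusing the same calibration tensors across the debug and non-debug runs would silently change the data being compared. A tiny standalone illustration (not from the test suite):

```python
import copy
import torch
import torch.nn.functional as F

x = torch.rand(4)
x_before = copy.deepcopy(x)

F.hardswish(x, inplace=True)      # mutates x in place
print(torch.equal(x, x_before))   # False: the calibration data changed
```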