graph mode: add hardswish inplace handling (#40284)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40284

Adds graph mode handling for inplace hardswish, and test coverage for functional hardswish.

Test Plan:
```
python test/test_quantization.py TestQuantizeScriptPTSQOps.test_hardswish
```

Imported from OSS

Differential Revision: D22140628

fbshipit-source-id: 55a514f7dc1130d510f69ee4e611d7cb5e08d02e
This commit is contained in:
Vasiliy Kuznetsov
2020-06-21 09:35:44 -07:00
committed by Facebook GitHub Bot
parent c6dbfcaf9e
commit ab8a99bd36
4 changed files with 27 additions and 5 deletions

View File

@ -2177,11 +2177,23 @@ class TestQuantizeJitOps(QuantizationTestCase):
.run(m.graph)
def test_hardswish(self):
data = [(torch.rand((1, 2, 5, 5), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
hardswish = torch.nn.Hardswish()
for tracing in [True, False]:
m = self.checkGraphModeOp(hardswish, data, "quantized::hardswish", tracing)
class FunctionalHardswish(torch.nn.Module):
def __init__(self, inplace):
super(FunctionalHardswish, self).__init__()
self.inplace = inplace
def forward(self, input):
return torch.nn.functional.hardswish(input, inplace=self.inplace)
modules = [torch.nn.Hardswish(), FunctionalHardswish(True),
FunctionalHardswish(False)]
for test_case in itertools.product([True, False], modules):
tracing, m = test_case
m = self.checkGraphModeOp(
m, self.img_data, "quantized::hardswish", tracing)
FileCheck().check_not("aten::hardswish") \
.check_not("aten::hardswish_") \
.run(m.graph)
def test_elu(self):

View File

@ -34,6 +34,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
"addmm",
"matmul",
"hardswish",
"hardswish_",
"elu",
"elu_",
"batch_norm",

View File

@ -848,6 +848,9 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
auto hardswish = getObservedQParamOpFusionInfo(
"aten::hardswish", "quantized::hardswish", {}, {});
auto hardswish_ = getObservedQParamOpFusionInfo(
"aten::hardswish_", "quantized::hardswish", {}, {});
auto layer_norm = getObservedQParamOpFusionInfo(
"aten::layer_norm",
"quantized::layer_norm",
@ -968,6 +971,7 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
{"quantized::mul", mul, quantized_mul},
{"quantized::mul", inplace_mul, quantized_mul},
hardswish,
hardswish_,
layer_norm,
group_norm,
instance_norm,

View File

@ -7,6 +7,7 @@ r"""Importing this file includes common utility methods and base clases for
checking quantization api and properties of resulting modules.
"""
import copy
import io
import functools
import torch
@ -282,8 +283,12 @@ class QuantizationTestCase(TestCase):
# make sure it runs
outputs[d] = models[d](inputs)
else:
# module under test can contain in-place ops, and we depend on
# input data staying constant for comparisons
data_copy = copy.deepcopy(data)
models[d] = quantize_jit(
model, qconfig_dict, test_only_eval_fn, [data], inplace=False, debug=d)
model, qconfig_dict, test_only_eval_fn, [data_copy], inplace=False,
debug=d)
# make sure it runs
outputs[d] = models[d](*inputs)