ns for fx: add linear-relu mod weight extraction (#55080)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55080

Adds support for extracting weights of linear-relu module pattern.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D27474701

fbshipit-source-id: 69ceaadc28d7fdcebd16d519367274d348b0dd29
This commit is contained in:
Vasiliy Kuznetsov
2021-04-14 08:59:38 -07:00
committed by Facebook GitHub Bot
parent 2587a28bbd
commit 444b318a90
2 changed files with 13 additions and 1 deletions

View File

@ -426,6 +426,12 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
# conv3d - relu
self.conv3d_1 = nn.Conv3d(1, 1, 1)
self.relu_2 = nn.ReLU()
# linear
self.linear_0 = nn.Linear(1, 1)
# linear - relu
self.linear_1 = nn.Linear(1, 1)
self.relu_3 = nn.ReLU()
def forward(self, x):
x = self.conv1d_0(x)
@ -439,10 +445,14 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
x = self.conv3d_0(x)
x = self.conv3d_1(x)
x = self.relu_2(x)
x = x.reshape(1, 1)
x = self.linear_0(x)
x = self.linear_1(x)
x = self.relu_3(x)
return x
m = M().eval()
self._test_extract_weights(m, results_len=6)
self._test_extract_weights(m, results_len=8)
@skipIfNoFBGEMM
def test_extract_weights_fun(self):

View File

@ -53,6 +53,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[Callable]]:
'torch.nn.Linear': set([
nn.Linear,
nnq.Linear,
nniq.LinearReLU,
nnqat.Linear,
nnqd.Linear,
]),
@ -153,6 +154,7 @@ def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
((nn.ReLU, nn.Conv1d), 0),
((nn.ReLU, nn.Conv2d), 0),
((nn.ReLU, nn.Conv3d), 0),
((nn.ReLU, nn.Linear), 0),
# linear-relu fp16 emulation:
# fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
((("to", torch.float16), F.relu, F.linear, "dequantize"), 1),