ns for fx: add linear-relu mod weight extraction (#55080)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55080

Adds support for extracting weights of the fused linear-relu module pattern.
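
As a rough illustration, here is a minimal sketch of how this could be exercised end to end. The import path and the exact `extract_weights` call signature are assumptions based on the Numeric Suite FX APIs of this era, and the toy model and result names are hypothetical:

```
# Hypothetical sketch: compare weights between a float model and its
# quantized counterpart, where linear + relu is fused into a single
# quantized LinearReLU module during FX graph mode quantization.
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig, quantize_fx
# Assumed import path for the NS for FX weight extraction API.
from torch.quantization._numeric_suite_fx import extract_weights

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

m = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
mp = quantize_fx.prepare_fx(m, qconfig_dict)  # fuses linear + relu
# (newer PyTorch versions also require example_inputs here)
mp(torch.randn(1, 4))                         # calibrate
mq = quantize_fx.convert_fx(mp)               # contains nniq.LinearReLU

# With this change, NS can match the float nn.Linear weight against the
# weight packed inside the fused quantized LinearReLU module.
results = extract_weights("fp32", m, "int8", mq)
```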

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D27474701

fbshipit-source-id: 69ceaadc28d7fdcebd16d519367274d348b0dd29
Author: Vasiliy Kuznetsov
Date: 2021-04-14 08:59:38 -07:00
Committed by: Facebook GitHub Bot
Parent: 2587a28bbd
Commit: 444b318a90
2 changed files with 13 additions and 1 deletion


@@ -53,6 +53,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[Callable]]:
'torch.nn.Linear': set([
nn.Linear,
nnq.Linear,
+nniq.LinearReLU,
nnqat.Linear,
nnqd.Linear,
]),
@@ -153,6 +154,7 @@ def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
((nn.ReLU, nn.Conv1d), 0),
((nn.ReLU, nn.Conv2d), 0),
((nn.ReLU, nn.Conv3d), 0),
+((nn.ReLU, nn.Linear), 0),
# linear-relu fp16 emulation:
# fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
((("to", torch.float16), F.relu, F.linear, "dequantize"), 1),