ns for fx: add linear-relu mod weight extraction (#55080)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55080

Adds support for extracting weights of the linear-relu module pattern.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D27474701

fbshipit-source-id: 69ceaadc28d7fdcebd16d519367274d348b0dd29
Committed by: Facebook GitHub Bot
Parent: 2587a28bbd
Commit: 444b318a90
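As a usage sketch, the pattern added by this PR would be exercised by extracting weights from a float model and its quantized counterpart, where the quantized model contains a fused `nniq.LinearReLU`. The example below assumes the `extract_weights` API from `torch.ao.ns._numeric_suite_fx` and the `prepare_fx`/`convert_fx` entry points; module paths and signatures have moved between PyTorch releases, so treat it as illustrative rather than the exact test code from this PR.

```python
# Illustrative only: the module paths and signatures used here (torch.ao.ns,
# prepare_fx, convert_fx, extract_weights) are assumptions about the current
# API surface, not taken verbatim from this PR.
import copy
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.ns._numeric_suite_fx import extract_weights

class LinearReLUModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

m_fp32 = LinearReLUModel().eval()
example_inputs = (torch.randn(1, 4),)

# FX graph mode quantization fuses linear + relu, so the converted model
# contains nniq.LinearReLU; this PR lets weight extraction match that fused
# module back to the float nn.Linear.
prepared = prepare_fx(copy.deepcopy(m_fp32), get_default_qconfig_mapping(), example_inputs)
prepared(*example_inputs)  # calibration pass
m_int8 = convert_fx(prepared)

# The results map each matched layer to the fp32 and int8 weight tensors.
results = extract_weights("fp32", m_fp32, "int8", m_int8)
print(list(results.keys()))
```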
@@ -53,6 +53,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[Callable]]:
         'torch.nn.Linear': set([
             nn.Linear,
             nnq.Linear,
+            nniq.LinearReLU,
             nnqat.Linear,
             nnqd.Linear,
         ]),
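The related-ops table above groups every flavor of an op (float, quantized, QAT, dynamically quantized, and now the fused quantized linear-relu) under one base name, so nodes from a float model and a quantized model can be paired even when their types differ. A minimal sketch of that lookup, using a hypothetical helper name rather than the Numeric Suite's internal one:

```python
# Hypothetical helper (not the Numeric Suite's internal function): invert the
# base-name -> related-ops table into a per-type lookup, so nn.Linear and
# nniq.LinearReLU resolve to the same base name and can be matched.
from typing import Callable, Dict, Set

def type_to_base_name(
    base_name_to_sets_of_related_ops: Dict[str, Set[Callable]],
) -> Dict[Callable, str]:
    lookup: Dict[Callable, str] = {}
    for base_name, related_ops in base_name_to_sets_of_related_ops.items():
        for op in related_ops:
            lookup[op] = base_name
    return lookup

# With the hunk above applied:
#   lookup[nn.Linear] == lookup[nniq.LinearReLU] == 'torch.nn.Linear'
```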
@@ -153,6 +154,7 @@ def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
         ((nn.ReLU, nn.Conv1d), 0),
         ((nn.ReLU, nn.Conv2d), 0),
         ((nn.ReLU, nn.Conv3d), 0),
+        ((nn.ReLU, nn.Linear), 0),
         # linear-relu fp16 emulation:
         # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
         ((("to", torch.float16), F.relu, F.linear, "dequantize"), 1),
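Each entry in `get_reversed_fusions()` pairs a fusion pattern stored last-op-first (matching how quantization fusion patterns are written) with an integer identifying the base op, i.e. the op whose weight is extracted and whose type is looked up in the related-ops table. The sketch below reads the new entry under the assumption that the integer indexes the base op in forward execution order:

```python
# Illustrative reading of the new entry, not the Numeric Suite's matcher.
# Assumption: the integer is the index of the base op in forward execution order.
import torch.nn as nn

reversed_fusion = ((nn.ReLU, nn.Linear), 0)  # the entry added in the hunk above

pattern_reversed, base_op_idx = reversed_fusion
pattern_forward = tuple(reversed(pattern_reversed))  # (nn.Linear, nn.ReLU): linear runs first
base_op = pattern_forward[base_op_idx]               # nn.Linear, the op that owns the weight

assert base_op is nn.Linear
```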