Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
ns for fx: add linear-relu mod weight extraction (#55080)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55080

Adds support for extracting weights of the linear-relu module pattern.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D27474701

fbshipit-source-id: 69ceaadc28d7fdcebd16d519367274d348b0dd29
committed by Facebook GitHub Bot
parent 2587a28bbd
commit 444b318a90
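
For context, a minimal sketch of how this weight-extraction path is typically exercised end to end. It is written against the present-day torch.ao.* module locations and signatures (prepare_fx/convert_fx taking example_inputs, extract_weights(name_a, model_a, name_b, model_b)); at the time of this PR the same functionality lived under torch.quantization.*, so paths and signatures may differ. Model and variable names are illustrative only.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
import torch.ao.ns._numeric_suite_fx as ns

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # linear - relu: the module pattern whose weight extraction this PR adds
        self.linear = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

m = M().eval()
example_inputs = (torch.randn(1, 4),)
mp = prepare_fx(m, get_default_qconfig_mapping("fbgemm"), example_inputs)
mp(*example_inputs)   # calibrate
mq = convert_fx(mp)   # linear + relu become a single quantized LinearReLU

# One result entry per matched weight; the fp32 nn.Linear weight is compared
# against the weight living inside the quantized LinearReLU module.
results = ns.extract_weights("fp32", mp, "int8", mq)
```

The test diff below checks exactly this kind of matching: results_len counts how many such weight pairs are found.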
@@ -426,6 +426,12 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
                 # conv3d - relu
                 self.conv3d_1 = nn.Conv3d(1, 1, 1)
                 self.relu_2 = nn.ReLU()
+                # linear
+                self.linear_0 = nn.Linear(1, 1)
+                # linear - relu
+                self.linear_1 = nn.Linear(1, 1)
+                self.relu_3 = nn.ReLU()
+
 
             def forward(self, x):
                 x = self.conv1d_0(x)
@@ -439,10 +445,14 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
                 x = self.conv3d_0(x)
                 x = self.conv3d_1(x)
                 x = self.relu_2(x)
+                x = x.reshape(1, 1)
+                x = self.linear_0(x)
+                x = self.linear_1(x)
+                x = self.relu_3(x)
                 return x
 
         m = M().eval()
-        self._test_extract_weights(m, results_len=6)
+        self._test_extract_weights(m, results_len=8)
 
     @skipIfNoFBGEMM
     def test_extract_weights_fun(self):
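Why the linear-relu pair needs dedicated handling: once the pattern is fused, the quantized model no longer contains a standalone Linear. A minimal eager-mode sketch below illustrates the fused type (the test above goes through the FX workflow; fuse_modules is used here only for illustration, with present-day torch.ao.quantization import paths).

```python
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.ao.quantization import fuse_modules

class LinearRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

m = LinearRelu().eval()
fused = fuse_modules(m, [["linear", "relu"]])
# The pair is replaced by a single fused module; after conversion this
# becomes torch.nn.intrinsic.quantized.LinearReLU (nniq.LinearReLU).
assert isinstance(fused.linear, nni.LinearReLU)
assert isinstance(fused.relu, nn.Identity)
```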
@@ -53,6 +53,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[Callable]]:
         'torch.nn.Linear': set([
             nn.Linear,
             nnq.Linear,
+            nniq.LinearReLU,
             nnqat.Linear,
             nnqd.Linear,
         ]),
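The entry above is what lets the graph matcher treat an fp32 nn.Linear and a quantized LinearReLU as two versions of the same logical layer. A toy sketch of the reverse-lookup idea follows; it is illustrative only, not the actual Numeric Suite code, and the real mapping contains many more entries.

```python
import torch.nn as nn
import torch.nn.intrinsic.quantized as nniq

# Trimmed-down stand-in for the mapping extended above.
base_name_to_sets_of_related_ops = {
    'torch.nn.Linear': {nn.Linear, nniq.LinearReLU},
}

def get_base_name(op_type):
    # Map any related op type back to its base name so that two ops from
    # different models can be recognized as the "same" layer.
    for base_name, related_ops in base_name_to_sets_of_related_ops.items():
        if op_type in related_ops:
            return base_name
    return None

assert get_base_name(nn.Linear) == 'torch.nn.Linear'
assert get_base_name(nniq.LinearReLU) == 'torch.nn.Linear'
```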
@@ -153,6 +154,7 @@ def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
         ((nn.ReLU, nn.Conv1d), 0),
         ((nn.ReLU, nn.Conv2d), 0),
         ((nn.ReLU, nn.Conv3d), 0),
+        ((nn.ReLU, nn.Linear), 0),
         # linear-relu fp16 emulation:
         # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
         ((("to", torch.float16), F.relu, F.linear, "dequantize"), 1),
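Reading the fusion entries above: each pattern is written last-op-first (ReLU before Linear), which suggests matching proceeds from the most downstream node of the fused subgraph backwards, and the integer reads as the execution-order index of the op that owns the weight (0 for nn.Linear in linear -> relu; 1 for F.linear in the dequantize -> linear -> relu -> to(float16) emulation pattern). That reading is inferred from the entries themselves; the toy matcher below only illustrates the reverse-order idea and is not the actual Numeric Suite implementation.

```python
from typing import List, Optional, Tuple

# (reversed_pattern, weighted_op_idx_in_execution_order), mirroring the shape
# of the entries added above but using plain strings for readability.
REVERSED_FUSIONS = [
    (("relu", "linear"), 0),
]

def match_reversed(chain_from_downstream: List[str]) -> Optional[Tuple[Tuple[str, ...], str]]:
    # chain_from_downstream lists op names starting at the current node and
    # walking towards the inputs, e.g. ["relu", "linear", "conv"].
    for pattern, weighted_idx in REVERSED_FUSIONS:
        n = len(pattern)
        if tuple(chain_from_downstream[:n]) == pattern:
            execution_order = tuple(reversed(chain_from_downstream[:n]))
            return execution_order, execution_order[weighted_idx]
    return None

# Starting from a relu whose producer is a linear: the fused subgraph is
# ("linear", "relu") and the weight to extract lives on "linear".
assert match_reversed(["relu", "linear", "conv"]) == (("linear", "relu"), "linear")
```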