[Quantization] add an option keep_original_weights in _lower_to_native_backend (#141049)
Differential Revision: D66153809

This diff adds a keep_original_weights option to _lower_to_native_backend so we can trace back to the original weight and bias after performing prepare_fx and convert_fx.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141049
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 809106a93f
Commit: 43853691bc
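A minimal sketch of the workflow this option enables, distilled from the test in the diff below. The keep_original_weights kwarg and the original_weights_lookup attribute come from this commit; the TinyLinear module and the printed summary are illustrative assumptions, not part of the change.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.ao.quantization import QConfigMapping, float16_dynamic_qconfig
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    # Illustrative module using a functional linear, as in the test below.
    class TinyLinear(nn.Module):
        def __init__(self):
            super().__init__()
            self.w = nn.Parameter(torch.randn(5, 10))
            self.b = nn.Parameter(torch.randn(5))

        def forward(self, x):
            return F.linear(x, self.w, self.b)

    model = TinyLinear().eval()
    example_inputs = (torch.randn(1, 10),)

    # Dynamic fp16 quantization of functional linears, matching the test's qconfig.
    qconfig_mapping = QConfigMapping().set_object_type(F.linear, float16_dynamic_qconfig)
    prepared = prepare_fx(model, qconfig_mapping, example_inputs)
    prepared(*example_inputs)
    quantized = convert_fx(prepared, keep_original_weights=True)

    # With keep_original_weights=True, the lowered module carries a lookup
    # from each packed weight back to the original (weight, bias) pair.
    for key, (w, b) in quantized.original_weights_lookup.items():
        print(key, tuple(w.shape), None if b is None else tuple(b.shape))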
@@ -6640,6 +6640,101 @@ class TestQuantizeFx(QuantizationTestCase):
        }
        self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)

    @skipIfNoFBGEMM
    def test_keep_original_weights(self):
        class SubModule(nn.Module):
            """
            A simple submodule containing a linear layer.
            """

            def __init__(self, input_dim, output_dim):
                super().__init__()
                self.w = nn.Parameter(torch.randn(input_dim, output_dim))
                self.b = nn.Parameter(torch.randn(input_dim))

            def forward(self, x):
                return F.linear(x, self.w, self.b)

        class MainModule(nn.Module):
            """
            The main module containing the submodules.
            """

            def __init__(self, input_dim, hidden_dim, output_dim):
                super().__init__()
                self.submodule_1 = SubModule(hidden_dim, input_dim)
                # Attribute names that are not valid Python identifiers; the
                # lookup keys asserted below show how they get sanitized.
                setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim))
                setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim))
                setattr(self, 'submodule:4', SubModule(hidden_dim, hidden_dim))
                self._w = nn.Parameter(torch.randn(output_dim, hidden_dim))

            def forward(self, x):
                x1 = self.submodule_1(x)
                x2 = getattr(self, 'submodule|2')(x1)
                x3 = getattr(self, 'submodule/3')(x2)
                x4 = getattr(self, 'submodule:4')(x3)
                x5 = F.linear(x4, self._w)
                return x5

        input_dim = 10
        hidden_dim = 20
        output_dim = 5
        model = MainModule(input_dim, hidden_dim, output_dim)
        model.eval()
        example_inputs = (torch.randn(1, input_dim),)
        _ = model(*example_inputs)
        qconfig_mapping = QConfigMapping().set_object_type(nn.functional.linear, float16_dynamic_qconfig)
        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
        prepared_model(*example_inputs)
        quantized_model = convert_fx(prepared_model, keep_original_weights=True)

        self.assertTrue(len(quantized_model.original_weights_lookup) == 5)
        self.assertTrue("submodule_1_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][0],
            model.submodule_1.w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][1],
            model.submodule_1.b
        )
        self.assertTrue("submodule_2_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][0],
            getattr(model, "submodule|2").w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][1],
            getattr(model, "submodule|2").b
        )
        self.assertTrue("submodule_3_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][0],
            getattr(model, "submodule/3").w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][1],
            getattr(model, "submodule/3").b
        )
        self.assertTrue("submodule_4_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][0],
            getattr(model, "submodule:4").w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][1],
            getattr(model, "submodule:4").b
        )
        self.assertTrue("_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["_packed_weight_0"][0],
            model._w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["_packed_weight_0"][1],
            None
        )

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
    def setUp(self):