[Quantization] add an option keep_original_weights in _lower_to_native_backend (#141049)

Differential Revision: D66153809

This diff adds a keep_original_weights option so that, after running prepare_fx and convert_fx, we can trace packed weights back to the original weight and bias.
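A minimal usage sketch (not part of this diff, and assuming an FBGEMM-enabled build, matching the skipIfNoFBGEMM guard on the new test below); TinyLinear and the tensor shapes are illustrative placeholders, while prepare_fx, convert_fx, QConfigMapping, float16_dynamic_qconfig, and original_weights_lookup are the pieces the new test exercises:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import QConfigMapping, float16_dynamic_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class TinyLinear(nn.Module):
    # Hypothetical float model with a functional linear, mirroring the test below.
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(5, 10))
        self.b = nn.Parameter(torch.randn(5))

    def forward(self, x):
        return F.linear(x, self.w, self.b)

model = TinyLinear().eval()
example_inputs = (torch.randn(1, 10),)
qconfig_mapping = QConfigMapping().set_object_type(F.linear, float16_dynamic_qconfig)

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)  # run the prepared model once, as the test does
quantized = convert_fx(prepared, keep_original_weights=True)

# With keep_original_weights=True, the converted GraphModule carries a lookup
# from packed-weight names back to the original float weight/bias pairs.
for name, (weight, bias) in quantized.original_weights_lookup.items():
    print(name, weight.shape, None if bias is None else bias.shape)

Per the new test, entries are keyed like <module path>_packed_weight_<n>, with characters such as '|', '/' and ':' in module paths normalized to '_', and each value holds the original weight at index 0 and the bias (or None) at index 1.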

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141049
Approved by: https://github.com/jerryzh168
Author: Huamin Li
Date: 2024-12-27 04:02:07 +00:00
Committed by: PyTorch MergeBot
parent 809106a93f
commit 43853691bc
5 changed files with 142 additions and 4 deletions

@@ -6640,6 +6640,101 @@ class TestQuantizeFx(QuantizationTestCase):
        }
        self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)

    @skipIfNoFBGEMM
    def test_keep_original_weights(self):
        class SubModule(nn.Module):
            """
            A simple submodule containing a linear layer.
            """

            def __init__(self, input_dim, output_dim):
                super().__init__()
                self.w = nn.Parameter(torch.randn(input_dim, output_dim))
                self.b = nn.Parameter(torch.randn(input_dim))

            def forward(self, x):
                return F.linear(x, self.w, self.b)

        class MainModule(nn.Module):
            """
            The main module containing the submodule.
            """

            def __init__(self, input_dim, hidden_dim, output_dim):
                super().__init__()
                self.submodule_1 = SubModule(hidden_dim, input_dim)
                setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim))
                setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim))
                setattr(self, 'submodule:4', SubModule(hidden_dim, hidden_dim))
                self._w = nn.Parameter(torch.randn(output_dim, hidden_dim))

            def forward(self, x):
                x1 = self.submodule_1(x)
                x2 = getattr(self, 'submodule|2')(x1)
                x3 = getattr(self, 'submodule/3')(x2)
                x4 = getattr(self, 'submodule:4')(x3)
                x5 = F.linear(x4, self._w)
                return x5

        input_dim = 10
        hidden_dim = 20
        output_dim = 5
        model = MainModule(input_dim, hidden_dim, output_dim)
        model.eval()
        example_inputs = (torch.randn(1, input_dim),)
        _ = model(*example_inputs)
        qconfig_mapping = QConfigMapping().set_object_type(nn.functional.linear, float16_dynamic_qconfig)
        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
        prepared_model(*example_inputs)
        quantized_model = convert_fx(prepared_model, keep_original_weights=True)

        self.assertTrue(len(quantized_model.original_weights_lookup) == 5)
        self.assertTrue("submodule_1_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][0],
            model.submodule_1.w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][1],
            model.submodule_1.b
        )
        self.assertTrue("submodule_2_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][0],
            getattr(model, "submodule|2").w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][1],
            getattr(model, "submodule|2").b
        )
        self.assertTrue("submodule_3_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][0],
            getattr(model, "submodule/3").w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][1],
            getattr(model, "submodule/3").b
        )
        self.assertTrue("submodule_4_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][0],
            getattr(model, "submodule:4").w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][1],
            getattr(model, "submodule:4").b
        )
        self.assertTrue("_packed_weight_0" in quantized_model.original_weights_lookup)
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["_packed_weight_0"][0],
            model._w
        )
        torch.testing.assert_close(
            quantized_model.original_weights_lookup["_packed_weight_0"][1],
            None
        )

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
    def setUp(self):