[Quantization] add an option keep_original_weights in _lower_to_native_backend (#141049)

Differential Revision: D66153809

This diff adds a keep_original_weights option so that the original weight and bias can be traced back after running prepare_fx and convert_fx.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141049
Approved by: https://github.com/jerryzh168
Authored by Huamin Li on 2024-12-27 04:02:07 +00:00
Committed by PyTorch MergeBot
parent 809106a93f
commit 43853691bc
5 changed files with 142 additions and 4 deletions
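
In short, the new flag threads from convert_fx through _convert_fx, convert, and lower_to_fbgemm into fold_weight, which records each prepacked op's original weight and bias on the converted GraphModule as original_weights_lookup. A minimal usage sketch (not part of the commit; it assumes an FBGEMM-capable build and simply mirrors the test added below):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import QConfigMapping, float16_dynamic_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class M(nn.Module):
    # Functional linear with explicit parameters, so lowering goes through a prepack op.
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4, 8))
        self.b = nn.Parameter(torch.randn(4))

    def forward(self, x):
        return F.linear(x, self.w, self.b)

model = M().eval()
example_inputs = (torch.randn(1, 8),)
qconfig_mapping = QConfigMapping().set_object_type(F.linear, float16_dynamic_qconfig)

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)
quantized = convert_fx(prepared, keep_original_weights=True)

# Each entry maps a packed-weight attribute name to [original_weight, original_bias_or_None].
for key, (weight, bias) in quantized.original_weights_lookup.items():
    print(key, tuple(weight.shape), None if bias is None else tuple(bias.shape))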

@@ -6640,6 +6640,101 @@ class TestQuantizeFx(QuantizationTestCase):
         }
         self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
+    @skipIfNoFBGEMM
+    def test_keep_original_weights(self):
+        class SubModule(nn.Module):
+            """
+            A simple submodule containing a linear layer.
+            """
+            def __init__(self, input_dim, output_dim):
+                super(__class__, self).__init__()
+                self.w = nn.Parameter(torch.randn(input_dim, output_dim))
+                self.b = nn.Parameter(torch.randn(input_dim))
+
+            def forward(self, x):
+                return F.linear(x, self.w, self.b)
+
+        class MainModule(nn.Module):
+            """
+            The main module containing the submodule.
+            """
+            def __init__(self, input_dim, hidden_dim, output_dim):
+                super(__class__, self).__init__()
+                self.submodule_1 = SubModule(hidden_dim, input_dim)
+                setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim))
+                setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim))
+                setattr(self, 'submodule:4', SubModule(hidden_dim, hidden_dim))
+                self._w = nn.Parameter(torch.randn(output_dim, hidden_dim))
+
+            def forward(self, x):
+                x1 = self.submodule_1(x)
+                x2 = getattr(self, 'submodule|2')(x1)
+                x3 = getattr(self, 'submodule/3')(x2)
+                x4 = getattr(self, 'submodule:4')(x3)
+                x5 = F.linear(x4, self._w)
+                return x5
+
+        input_dim = 10
+        hidden_dim = 20
+        output_dim = 5
+        model = MainModule(input_dim, hidden_dim, output_dim)
+        model.eval()
+        example_inputs = torch.randn(1, input_dim)
+        _ = model(*example_inputs)
+        qconfig_mapping = QConfigMapping().set_object_type(nn.functional.linear, float16_dynamic_qconfig)
+        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
+        prepared_model(example_inputs)
+        quantized_model = convert_fx(prepared_model, keep_original_weights=True)
+
+        self.assertTrue(len(quantized_model.original_weights_lookup) == 5)
+        self.assertTrue("submodule_1_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][0],
+            model.submodule_1.w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][1],
+            model.submodule_1.b
+        )
+        self.assertTrue("submodule_2_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][0],
+            getattr(model, "submodule|2").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][1],
+            getattr(model, "submodule|2").b
+        )
+        self.assertTrue("submodule_3_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][0],
+            getattr(model, "submodule/3").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][1],
+            getattr(model, "submodule/3").b
+        )
+        self.assertTrue("submodule_4_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][0],
+            getattr(model, "submodule:4").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][1],
+            getattr(model, "submodule:4").b
+        )
+        self.assertTrue("_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["_packed_weight_0"][0],
+            model._w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["_packed_weight_0"][1],
+            None
+        )
+
 
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     def setUp(self):
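
The lookup keys asserted above ("submodule_2_packed_weight_0", "_packed_weight_0", and so on) come from the scope-derived packed-weight attribute name, with "|", "/", and ":" replaced by underscores and the result lowercased. A hypothetical helper illustrating that sanitization (the _to_lookup_key name and the example packed-weight names are illustrative, not from the commit):

# Illustrative only: mirrors the .replace(...)/.lower() chain added to fold_weight below.
def _to_lookup_key(packed_weight_name: str) -> str:
    return (
        packed_weight_name.replace(":", "_")
        .replace("/", "_")
        .replace("|", "_")
        .lower()
    )

assert _to_lookup_key("submodule|2_packed_weight_0") == "submodule_2_packed_weight_0"
assert _to_lookup_key("submodule:4_packed_weight_0") == "submodule_4_packed_weight_0"
assert _to_lookup_key("_packed_weight_0") == "_packed_weight_0"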

@@ -443,7 +443,9 @@ def _load_packed_weight(
 def fold_weight(
-    quantized_model: GraphModule, node_name_to_scope: Dict[str, Tuple[str, type]]
+    quantized_model: GraphModule,
+    node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """
     Trace back from the weight node until we hit getattr, reconstruct the
@@ -453,6 +455,8 @@ def fold_weight(
     packed_weights = {}
     # map from folded node name to the prepacked weight name
     folded_nodes = {}
+    original_weights_lookup: Dict[str, List] = {}
+    lookup_counter = 0
     # get packed weights
     for node in quantized_model.graph.nodes:
         if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS:
@@ -466,6 +470,16 @@
                 )
                 packed_weight = prepacking_module()
                 packed_weights[node.name] = packed_weight
+                if keep_original_weights:
+                    original_weights = list(prepacking_module.state_dict().values())
+                    original_weights_lookup[str(lookup_counter)] = sorted(
+                        original_weights, key=lambda x: x.numel(), reverse=True
+                    )
+                    if len(original_weights_lookup[str(lookup_counter)]) == 1:
+                        # bias is None
+                        original_weights_lookup[str(lookup_counter)].append(None)
+                    lookup_counter += 1
+    lookup_counter = 0
 
     # remove folded nodes and replace the prepacking node with getattr
     folded_graph = Graph()
@@ -490,6 +504,18 @@
             env[node.name] = folded_graph.create_node(
                 "get_attr", packed_weight_name, (), {}
             )
+            if keep_original_weights:
+                key_name = (
+                    packed_weight_name.replace(":", "_")
+                    .replace("/", "_")
+                    .replace("|", "_")
+                    .lower()
+                )
+                original_weights_lookup[key_name] = original_weights_lookup[
+                    str(lookup_counter)
+                ]
+                del original_weights_lookup[str(lookup_counter)]
+                lookup_counter += 1
         elif prepack_node is not None:
             # remove the folded node
             continue
@@ -500,6 +526,12 @@
     quantized_model = GraphModule(quantized_model, folded_graph)
     quantized_model._register_state_dict_hook(_save_packed_weight)
     quantized_model.register_load_state_dict_pre_hook(_load_packed_weight)
+
+    if keep_original_weights:
+        setattr(  # noqa: B010
+            quantized_model, "original_weights_lookup", original_weights_lookup
+        )
+
     return quantized_model
@@ -1296,6 +1328,7 @@ def _lower_to_native_backend(
     model: GraphModule,
     qconfig_map: Dict[str, QConfigAny],
     node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """Lower a quantized reference model (with reference quantized operator patterns)
     to the native backend in PyTorch (fbgemm/qnnpack), both backends share the same
@@ -1312,7 +1345,7 @@ def _lower_to_native_backend(
     _lower_get_tensor_info_op(model)
     special_pattern_replacement(model)
     model.graph.eliminate_dead_code()
-    model = fold_weight(model, node_name_to_scope)
+    model = fold_weight(model, node_name_to_scope, keep_original_weights)
     model.graph.eliminate_dead_code()
     model.recompile()
     model.graph.lint()
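
Put differently, the fold_weight changes build the lookup in two passes: while packing, the original tensors of each prepacking module are stored under a temporary counter key (weight first, then bias, with None appended when there is no bias); while rebuilding the graph, each counter key is renamed to the sanitized packed-weight attribute name. A condensed, hypothetical sketch of that bookkeeping (build_lookup and its arguments are made up for illustration, not part of the commit):

from typing import Dict, List, Optional

import torch

def build_lookup(
    prepacked_state_dicts: List[Dict[str, torch.Tensor]],
    packed_weight_names: List[str],
) -> Dict[str, List[Optional[torch.Tensor]]]:
    lookup: Dict[str, List[Optional[torch.Tensor]]] = {}
    # Pass 1: store [weight, bias] under a running counter; the larger tensor is the weight.
    for counter, state_dict in enumerate(prepacked_state_dicts):
        tensors = sorted(state_dict.values(), key=lambda t: t.numel(), reverse=True)
        entry: List[Optional[torch.Tensor]] = list(tensors)
        if len(entry) == 1:
            entry.append(None)  # bias is None
        lookup[str(counter)] = entry
    # Pass 2: rename each counter key to the sanitized packed-weight attribute name.
    for counter, name in enumerate(packed_weight_names):
        key = name.replace(":", "_").replace("/", "_").replace("|", "_").lower()
        lookup[key] = lookup.pop(str(counter))
    return lookup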

@@ -992,6 +992,7 @@ def convert(
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
     is_decomposed: bool = False,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """
     We will convert an observed model (a module with observer calls) to a reference
@@ -1243,7 +1244,9 @@
     # TODO: maybe move this to quantize_fx.py
     if not is_reference:
-        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)
+        model = lower_to_fbgemm(
+            model, node_name_to_qconfig, node_name_to_scope, keep_original_weights
+        )
 
     # TODO: this looks hacky, we want to check why we need this and see if we can
     # remove this

@@ -13,8 +13,11 @@ def lower_to_fbgemm(
     model: GraphModule,
     qconfig_map: Dict[str, QConfigAny],
     node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """Lower a quantized reference model (with reference quantized operator patterns)
     to fbgemm
     """
-    return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
+    return _lower_to_native_backend(
+        model, qconfig_map, node_name_to_scope, keep_original_weights
+    )

@@ -515,6 +515,7 @@ def _convert_fx(
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
     is_decomposed: bool = False,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`"""
     if convert_custom_config is None:
@@ -546,6 +547,7 @@
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
         is_decomposed=is_decomposed,
+        keep_original_weights=keep_original_weights,
     )
 
     attach_preserved_attrs_to_model(quantized, preserved_attrs)
@@ -558,6 +560,7 @@
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     r"""Convert a calibrated or trained model to a quantized model
@@ -616,6 +619,7 @@
         _remove_qconfig=_remove_qconfig,
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
+        keep_original_weights=keep_original_weights,
     )