[Quantization] add an option keep_original_weights in _lower_to_native_backend (#141049)
Differential Revision: D66153809

This diff adds a keep_original_weights option so that, after running prepare_fx and convert_fx, the original weight and bias can be traced back from the lowered model.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141049
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 809106a93f
Commit: 43853691bc
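A minimal usage sketch of the new option, modeled on the test added below; the module, the dimensions, and the "_packed_weight_0" key are illustrative and simply follow the pattern exercised by that test:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import QConfigMapping, float16_dynamic_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(5, 10))
        self.b = nn.Parameter(torch.randn(5))

    def forward(self, x):
        return F.linear(x, self.w, self.b)

model = M().eval()
example_inputs = (torch.randn(1, 10),)
qconfig_mapping = QConfigMapping().set_object_type(F.linear, float16_dynamic_qconfig)

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)
quantized = convert_fx(prepared, keep_original_weights=True)

# With keep_original_weights=True the converted model carries original_weights_lookup,
# mapping each packed-weight name to [original_weight, original_bias_or_None].
weight, bias = quantized.original_weights_lookup["_packed_weight_0"]
torch.testing.assert_close(weight, model.w)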
@@ -6640,6 +6640,101 @@ class TestQuantizeFx(QuantizationTestCase):
         }
         self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
 
+    @skipIfNoFBGEMM
+    def test_keep_original_weights(self):
+        class SubModule(nn.Module):
+            """
+            A simple submodule containing a linear layer.
+            """
+
+            def __init__(self, input_dim, output_dim):
+                super(__class__, self).__init__()
+                self.w = nn.Parameter(torch.randn(input_dim, output_dim))
+                self.b = nn.Parameter(torch.randn(input_dim))
+
+            def forward(self, x):
+                return F.linear(x, self.w, self.b)
+
+        class MainModule(nn.Module):
+            """
+            The main module containing the submodule.
+            """
+
+            def __init__(self, input_dim, hidden_dim, output_dim):
+                super(__class__, self).__init__()
+                self.submodule_1 = SubModule(hidden_dim, input_dim)
+                setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim))
+                setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim))
+                setattr(self, 'submodule:4', SubModule(hidden_dim, hidden_dim))
+                self._w = nn.Parameter(torch.randn(output_dim, hidden_dim))
+
+            def forward(self, x):
+                x1 = self.submodule_1(x)
+                x2 = getattr(self, 'submodule|2')(x1)
+                x3 = getattr(self, 'submodule/3')(x2)
+                x4 = getattr(self, 'submodule:4')(x3)
+                x5 = F.linear(x4, self._w)
+                return x5
+
+        input_dim = 10
+        hidden_dim = 20
+        output_dim = 5
+        model = MainModule(input_dim, hidden_dim, output_dim)
+        model.eval()
+        example_inputs = torch.randn(1, input_dim)
+        _ = model(*example_inputs)
+        qconfig_mapping = QConfigMapping().set_object_type(nn.functional.linear, float16_dynamic_qconfig)
+        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
+        prepared_model(example_inputs)
+        quantized_model = convert_fx(prepared_model, keep_original_weights=True)
+
+        self.assertTrue(len(quantized_model.original_weights_lookup) == 5)
+        self.assertTrue("submodule_1_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][0],
+            model.submodule_1.w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][1],
+            model.submodule_1.b
+        )
+        self.assertTrue("submodule_2_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][0],
+            getattr(model, "submodule|2").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][1],
+            getattr(model, "submodule|2").b
+        )
+        self.assertTrue("submodule_3_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][0],
+            getattr(model, "submodule/3").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][1],
+            getattr(model, "submodule/3").b
+        )
+        self.assertTrue("submodule_4_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][0],
+            getattr(model, "submodule:4").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][1],
+            getattr(model, "submodule:4").b
+        )
+        self.assertTrue("_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["_packed_weight_0"][0],
+            model._w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["_packed_weight_0"][1],
+            None
+        )
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     def setUp(self):
@@ -443,7 +443,9 @@ def _load_packed_weight(
 
 
 def fold_weight(
-    quantized_model: GraphModule, node_name_to_scope: Dict[str, Tuple[str, type]]
+    quantized_model: GraphModule,
+    node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """
     Trace back from the weight node until we hit getattr, reconstruct the
@@ -453,6 +455,8 @@ def fold_weight(
     packed_weights = {}
     # map from folded node name to the prepacked weight name
     folded_nodes = {}
+    original_weights_lookup: Dict[str, List] = {}
+    lookup_counter = 0
     # get packed weights
     for node in quantized_model.graph.nodes:
         if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS:
@@ -466,6 +470,16 @@ def fold_weight(
                 )
                 packed_weight = prepacking_module()
                 packed_weights[node.name] = packed_weight
+                if keep_original_weights:
+                    original_weights = list(prepacking_module.state_dict().values())
+                    original_weights_lookup[str(lookup_counter)] = sorted(
+                        original_weights, key=lambda x: x.numel(), reverse=True
+                    )
+                    if len(original_weights_lookup[str(lookup_counter)]) == 1:
+                        # bias is None
+                        original_weights_lookup[str(lookup_counter)].append(None)
+                    lookup_counter += 1
+    lookup_counter = 0
 
     # remove folded nodes and replace the prepacking node with getattr
     folded_graph = Graph()
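A note on the recovery step above: prepacking_module.state_dict() holds the original weight and, when present, the bias; sorting by numel() in descending order puts the weight (out_features x in_features elements) ahead of the bias (out_features elements), and a lone entry gets None appended in the bias slot. A standalone sketch of that convention, with illustrative shapes:

import torch

# e.g. values recovered from a prepacking module for a 20 -> 10 linear
values = [torch.randn(10), torch.randn(10, 20)]  # bias (10 elements), weight (200 elements)
entry = sorted(values, key=lambda x: x.numel(), reverse=True)  # [weight, bias]
if len(entry) == 1:
    entry.append(None)  # no bias was packed
assert entry[0].shape == (10, 20) and entry[1].shape == (10,)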
@@ -490,6 +504,18 @@ def fold_weight(
             env[node.name] = folded_graph.create_node(
                 "get_attr", packed_weight_name, (), {}
             )
+            if keep_original_weights:
+                key_name = (
+                    packed_weight_name.replace(":", "_")
+                    .replace("/", "_")
+                    .replace("|", "_")
+                    .lower()
+                )
+                original_weights_lookup[key_name] = original_weights_lookup[
+                    str(lookup_counter)
+                ]
+                del original_weights_lookup[str(lookup_counter)]
+                lookup_counter += 1
         elif prepack_node is not None:
             # remove the folded node
             continue
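The lookup key above is derived from the generated packed-weight attribute name, which embeds the module scope; ':' , '/', and '|' (legal in attribute names set via setattr, as the new test shows) are normalized to underscores and the result is lowercased. A quick sketch of that mapping, with illustrative names matching the test's submodules:

def to_lookup_key(packed_weight_name: str) -> str:
    # mirrors the replace/lower chain added to fold_weight above
    return (
        packed_weight_name.replace(":", "_")
        .replace("/", "_")
        .replace("|", "_")
        .lower()
    )

assert to_lookup_key("submodule|2_packed_weight_0") == "submodule_2_packed_weight_0"
assert to_lookup_key("submodule:4_packed_weight_0") == "submodule_4_packed_weight_0"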
@@ -500,6 +526,12 @@ def fold_weight(
     quantized_model = GraphModule(quantized_model, folded_graph)
     quantized_model._register_state_dict_hook(_save_packed_weight)
     quantized_model.register_load_state_dict_pre_hook(_load_packed_weight)
+
+    if keep_original_weights:
+        setattr(  # noqa: B010
+            quantized_model, "original_weights_lookup", original_weights_lookup
+        )
+
     return quantized_model
 
 
@@ -1296,6 +1328,7 @@ def _lower_to_native_backend(
     model: GraphModule,
     qconfig_map: Dict[str, QConfigAny],
     node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """Lower a quantized reference model (with reference quantized operator patterns)
     to the native backend in PyTorch (fbgemm/qnnpack), both backends share the same
@@ -1312,7 +1345,7 @@ def _lower_to_native_backend(
     _lower_get_tensor_info_op(model)
     special_pattern_replacement(model)
     model.graph.eliminate_dead_code()
-    model = fold_weight(model, node_name_to_scope)
+    model = fold_weight(model, node_name_to_scope, keep_original_weights)
     model.graph.eliminate_dead_code()
     model.recompile()
     model.graph.lint()
@@ -992,6 +992,7 @@ def convert(
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
     is_decomposed: bool = False,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """
     We will convert an observed model (a module with observer calls) to a reference
@@ -1243,7 +1244,9 @@ def convert(
 
     # TODO: maybe move this to quantize_fx.py
     if not is_reference:
-        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)
+        model = lower_to_fbgemm(
+            model, node_name_to_qconfig, node_name_to_scope, keep_original_weights
+        )
 
     # TODO: this looks hacky, we want to check why we need this and see if we can
     # remove this
@@ -13,8 +13,11 @@ def lower_to_fbgemm(
     model: GraphModule,
     qconfig_map: Dict[str, QConfigAny],
     node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """Lower a quantized reference model (with reference quantized operator patterns)
     to fbgemm
     """
-    return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
+    return _lower_to_native_backend(
+        model, qconfig_map, node_name_to_scope, keep_original_weights
+    )
@@ -515,6 +515,7 @@ def _convert_fx(
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
     is_decomposed: bool = False,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`"""
     if convert_custom_config is None:
@@ -546,6 +547,7 @@ def _convert_fx(
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
         is_decomposed=is_decomposed,
+        keep_original_weights=keep_original_weights,
     )
 
     attach_preserved_attrs_to_model(quantized, preserved_attrs)
@@ -558,6 +560,7 @@ def convert_fx(
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     r"""Convert a calibrated or trained model to a quantized model
 
@@ -616,6 +619,7 @@ def convert_fx(
         _remove_qconfig=_remove_qconfig,
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
+        keep_original_weights=keep_original_weights,
     )
 
 