[quant][be] Remove unused helper functions in convert.py (#86913)

Summary:
att

Test Plan:
python test/test_quantization.py TestQuantizeFx

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86913
Approved by: https://github.com/vkuzo
This commit is contained in:
Jerry Zhang
2022-10-13 17:02:32 -07:00
committed by PyTorch MergeBot
parent 761ca20dd8
commit b8aa1767cd

View File

@ -77,14 +77,11 @@ __all__ = [
"convert_custom_module",
"convert_standalone_module",
"convert_weighted_module",
"duplicate_dequantize_node",
"duplicate_quantize_dynamic_node",
"get_module_path_and_prefix",
"has_none_qconfig",
"insert_dequantize_node",
"maybe_get_observer_for_node",
"maybe_recursive_remove_dequantize",
"remove_extra_dequantize",
"restore_state",
"run_weight_observers",
]
@ -129,68 +126,6 @@ def run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -
# run the weight observer
weight_observer_module()
# this method is temporary will be removed soon
def duplicate_quantize_dynamic_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
quantized_root = quantized
for node in quantized.graph.nodes:
if (node.op == "call_function" and node.target == torch.quantize_per_tensor_dynamic):
users = list(node.users)
if len(users) > 1:
for user in users:
with quantized.graph.inserting_before(node):
new_node = quantized.graph.create_node(
"call_function",
torch.quantize_per_tensor_dynamic,
node.args,
node.kwargs)
user.replace_input_with(node, new_node)
quantized.graph.erase_node(node)
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
return quantized
def duplicate_dequantize_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
"""
If a dequantize node has multiple uses, duplicate it and create one dequantize node for each use.
This is to enable the pattern matching to map from individual quant - dequant - ref_module to
final quantized module.
"""
quantized_root = quantized
for node in quantized.graph.nodes:
if (node.op == "call_method" and node.target == "dequantize" or
(node.op == "call_function" and node.target == torch.dequantize)):
users = list(node.users)
if len(users) > 1:
for user in users:
with quantized.graph.inserting_before(node):
new_node = quantized.graph.create_node("call_method", "dequantize", node.args, {})
user.replace_input_with(node, new_node)
quantized.graph.erase_node(node)
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
return quantized
def remove_extra_dequantize(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
"""
Removes duplicate dequant nodes in the graph, for an operator that has multiple dequant nodes as a user,
replace them with a single dequant node that can be shared across all the uses.
"""
quantized_root = quantized
for node in quantized.graph.nodes:
users = list(node.users)
dequant_users = [user for user in node.users if user.op == "call_method" and user.target == "dequantize" or
(user.op == "call_function" and user.target == torch.dequantize)]
if len(dequant_users) > 1:
with quantized.graph.inserting_after(node):
unique_dq = quantized.graph.create_node("call_method", "dequantize", users[0].args, {})
for dequant in dequant_users:
dequant.replace_all_uses_with(unique_dq)
quantized.graph.erase_node(dequant)
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
return quantized
def maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
""" If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
we'll recursively remove the dequantize Node