mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant][be] Remove unused helper functions in convert.py (#86913)
Summary: att Test Plan: python test/test_quantization.py TestQuantizeFx Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/86913 Approved by: https://github.com/vkuzo
This commit is contained in:
committed by
PyTorch MergeBot
parent
761ca20dd8
commit
b8aa1767cd
@ -77,14 +77,11 @@ __all__ = [
|
||||
"convert_custom_module",
|
||||
"convert_standalone_module",
|
||||
"convert_weighted_module",
|
||||
"duplicate_dequantize_node",
|
||||
"duplicate_quantize_dynamic_node",
|
||||
"get_module_path_and_prefix",
|
||||
"has_none_qconfig",
|
||||
"insert_dequantize_node",
|
||||
"maybe_get_observer_for_node",
|
||||
"maybe_recursive_remove_dequantize",
|
||||
"remove_extra_dequantize",
|
||||
"restore_state",
|
||||
"run_weight_observers",
|
||||
]
|
||||
@ -129,68 +126,6 @@ def run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -
|
||||
# run the weight observer
|
||||
weight_observer_module()
|
||||
|
||||
# this method is temporary will be removed soon
|
||||
def duplicate_quantize_dynamic_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
|
||||
quantized_root = quantized
|
||||
for node in quantized.graph.nodes:
|
||||
if (node.op == "call_function" and node.target == torch.quantize_per_tensor_dynamic):
|
||||
users = list(node.users)
|
||||
if len(users) > 1:
|
||||
for user in users:
|
||||
with quantized.graph.inserting_before(node):
|
||||
new_node = quantized.graph.create_node(
|
||||
"call_function",
|
||||
torch.quantize_per_tensor_dynamic,
|
||||
node.args,
|
||||
node.kwargs)
|
||||
user.replace_input_with(node, new_node)
|
||||
quantized.graph.erase_node(node)
|
||||
|
||||
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
|
||||
return quantized
|
||||
|
||||
def duplicate_dequantize_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
|
||||
"""
|
||||
If a dequantize node has multiple uses, duplicate it and create one dequantize node for each use.
|
||||
This is to enable the pattern matching to map from individual quant - dequant - ref_module to
|
||||
final quantized module.
|
||||
"""
|
||||
quantized_root = quantized
|
||||
for node in quantized.graph.nodes:
|
||||
if (node.op == "call_method" and node.target == "dequantize" or
|
||||
(node.op == "call_function" and node.target == torch.dequantize)):
|
||||
users = list(node.users)
|
||||
if len(users) > 1:
|
||||
for user in users:
|
||||
with quantized.graph.inserting_before(node):
|
||||
new_node = quantized.graph.create_node("call_method", "dequantize", node.args, {})
|
||||
user.replace_input_with(node, new_node)
|
||||
quantized.graph.erase_node(node)
|
||||
|
||||
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
|
||||
return quantized
|
||||
|
||||
def remove_extra_dequantize(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
|
||||
"""
|
||||
Removes duplicate dequant nodes in the graph, for an operator that has multiple dequant nodes as a user,
|
||||
replace them with a single dequant node that can be shared across all the uses.
|
||||
"""
|
||||
quantized_root = quantized
|
||||
for node in quantized.graph.nodes:
|
||||
users = list(node.users)
|
||||
dequant_users = [user for user in node.users if user.op == "call_method" and user.target == "dequantize" or
|
||||
(user.op == "call_function" and user.target == torch.dequantize)]
|
||||
|
||||
if len(dequant_users) > 1:
|
||||
with quantized.graph.inserting_after(node):
|
||||
unique_dq = quantized.graph.create_node("call_method", "dequantize", users[0].args, {})
|
||||
for dequant in dequant_users:
|
||||
dequant.replace_all_uses_with(unique_dq)
|
||||
quantized.graph.erase_node(dequant)
|
||||
|
||||
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
|
||||
return quantized
|
||||
|
||||
def maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
|
||||
""" If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
|
||||
we'll recursively remove the dequantize Node
|
||||
|
Reference in New Issue
Block a user