Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 12:54:11 +08:00
Update get_aten_graph_module (#121937)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121937 Approved by: https://github.com/andrewor14
Committed by: PyTorch MergeBot
Parent: af86d67d61
Commit: 53d2188df9
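
Summary: this commit renames the pattern-capture helper get_aten_graph_module to the private-by-convention name _get_aten_graph_module_for_pattern, updates every call site under torch.ao.quantization, drops the old name from the docs coverage-ignore lists, and adds a cleanup that erases dangling aten.copy_ nodes left behind by ep.module(). A minimal sketch of the call shape after the rename; the pattern function and example inputs below are hypothetical stand-ins:

    import torch
    from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern

    # Hypothetical pattern: any traceable callable works, per the call
    # sites in this diff.
    def my_pattern(x):
        return torch.nn.functional.relu(x)

    example_inputs = (torch.randn(1, 3, 3, 3),)
    # Captures the pattern into an ATen-level GraphModule for subgraph matching.
    pattern_gm = _get_aten_graph_module_for_pattern(my_pattern, example_inputs)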
@@ -309,7 +309,6 @@ coverage_ignore_functions = [
     "reference_representation_rewrite",
     # torch.ao.quantization.pt2e.utils
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
     "remove_tensor_overload_for_qdq_ops",
     # torch.ao.quantization.qconfig
     "get_default_qat_qconfig",
@@ -1519,7 +1519,6 @@
         "SharedQuantizationSpec",
         "Tuple",
         "fold_bn_weights_into_conv_node",
-        "get_aten_graph_module",
         "replace_pattern_with_filters"
     ],
     "torch.ao.quantization.quantize_fx": [
@@ -46,7 +46,7 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     See https://github.com/pytorch/pytorch/issues/103681.
     """
     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern

     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -62,17 +62,17 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):

     example_inputs = (torch.randn(1),)
     if train_to_eval:
-        match_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_train), example_inputs
         )
-        replacement_pattern = get_aten_graph_module(
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_eval), example_inputs
         )
     else:
-        match_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_eval), example_inputs
         )
-        replacement_pattern = get_aten_graph_module(
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_train), example_inputs
         )
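
Note: the hunk above captures dropout in both training and eval form and swaps one for the other depending on direction. The diff only shows the two callables being wrapped, not their definitions; a plausible reconstruction, offered as an assumption:

    import torch.nn.functional as F

    # Assumed bodies of the two pattern callables referenced above.
    def dropout_train(x):
        return F.dropout(x, p=0.5, training=True)

    def dropout_eval(x):
        return F.dropout(x, p=0.5, training=False)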
@@ -101,7 +101,7 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
     # Enable this support in future updates.

     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern

     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -137,13 +137,17 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
         torch.randn(1),  # bn_running_var
     )
     if train_to_eval:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_train), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_train), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_eval), example_inputs
         )
     else:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_eval), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_eval), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_train), example_inputs
         )
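
Note: unlike dropout, the batchnorm patterns take the weight, bias, and running statistics as explicit inputs, so the example-input tuple carries one tensor per parameter and buffer. Only the bn_running_var entry is visible in the hunk above; the full tuple is assumed to look roughly like:

    import torch

    # Assumed shapes; only the last entry appears in the diff.
    example_inputs = (
        torch.randn(1, 1, 3, 3),  # x
        torch.randn(1),           # bn_weight
        torch.randn(1),           # bn_bias
        torch.randn(1),           # bn_running_mean
        torch.randn(1),           # bn_running_var
    )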
@@ -24,7 +24,7 @@ from .utils import (
     _is_conv,
     _is_bn_node,
     fold_bn_weights_into_conv_node,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )

 if TYPE_CHECKING:
@@ -546,7 +546,7 @@ def _fuse_conv_bn_qat_helper(
     m.graph.eliminate_dead_code()
     m.recompile()
     conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
-    match_pattern = get_aten_graph_module(conv_bn_pattern, example_inputs, is_cuda)
+    match_pattern = _get_aten_graph_module_for_pattern(conv_bn_pattern, example_inputs, is_cuda)

     # Step (1): Replace patterns with conv bias
     #
@@ -555,7 +555,7 @@ def _fuse_conv_bn_qat_helper(
     # TODO: use the public replace_pattern API once it also returns replacement nodes

     qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
-    replacement_pattern_with_conv_bias = get_aten_graph_module(
+    replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern,
         example_inputs,
         is_cuda,
@@ -572,7 +572,7 @@ def _fuse_conv_bn_qat_helper(
     # Step (2): Replace patterns without conv bias

     qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
-    replacement_pattern_no_conv_bias = get_aten_graph_module(
+    replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern_no_conv_bias,
         example_inputs,
         is_cuda,
@@ -738,11 +738,11 @@ def _fold_conv_bn_qat_helper(
     match_pattern = _get_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
+    match_pattern = _get_aten_graph_module_for_pattern(match_pattern, example_inputs, is_cuda, **kwargs)
     replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
+    replacement_pattern = _get_aten_graph_module_for_pattern(replacement_pattern, example_inputs, is_cuda, **kwargs)
     replacements.extend(
         replace_pattern_with_filters(
             m,
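
Note: replace_pattern_with_filters comes from torch.fx.subgraph_rewriter. Per the TODO above, it is used instead of the public replace_pattern because it also returns the replacement nodes, which the QAT pass post-processes. A minimal, self-contained sketch; the toy model, pattern, replacement, and filter are hypothetical:

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    def toy_pattern(x):
        return torch.relu(x) + 1

    def toy_replacement(x):
        return torch.clamp(x, min=0) + 1

    def accept_all(match, original_graph, pattern_graph):
        # A match filter can veto individual matches; this one accepts all.
        return True

    def toy_model(x):
        return torch.relu(x) + 1

    gm = symbolic_trace(toy_model)
    # Each returned entry records the matched anchor and its replacement nodes.
    replaced = replace_pattern_with_filters(
        gm, toy_pattern, toy_replacement, match_filters=[accept_all]
    )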
@@ -2,7 +2,7 @@ import torch
 from torch.fx import GraphModule
 from ..export_utils import _WrapperModule
 from ..utils import (
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
     remove_tensor_overload_for_qdq_ops,
     _replace_literals_with_new_placeholders,
     _replace_literals_with_existing_placeholders,
@@ -586,9 +586,9 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         replacement = rewrite_info.replacement
         pattern_post_trans = rewrite_info.pattern_post_trans
         replacement_post_trans = rewrite_info.replacement_post_trans
-        pattern = get_aten_graph_module(pattern, example_inputs)  # type: ignore[arg-type, assignment]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
-        replacement = get_aten_graph_module(replacement, example_inputs)  # type: ignore[arg-type, assignment]
+        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
         if pattern_post_trans:
             pattern = pattern_post_trans(pattern)
@@ -20,7 +20,7 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation

 __all__ = [
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
+    "_get_aten_graph_module_for_pattern",
     "remove_tensor_overload_for_qdq_ops",
 ]
@@ -292,7 +292,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
             node_name_to_scope[n.name] = current_scope
     return node_name_to_scope

-def get_aten_graph_module(
+def _get_aten_graph_module_for_pattern(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
     is_cuda: bool = False,
@@ -310,6 +310,16 @@
     )
     aten_pattern.graph.eliminate_dead_code()
     aten_pattern.recompile()
+
+    # ep.module() adds copy_ nodes for the mutated inputs.
+    # For patterns, it doesn't matter
+    for node in aten_pattern.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default and len(node.users) == 0:
+            aten_pattern.graph.erase_node(node)
+
+    aten_pattern.graph.eliminate_dead_code()
+    aten_pattern.recompile()
+
     return aten_pattern

 def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
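
Note: the added block above is the behavioral change in this commit. Exporting a pattern and calling ep.module() inserts aten.copy_ nodes for mutated inputs (batchnorm in training mode, for example, mutates its running statistics), and a user-less copy_ left in a pattern graph would otherwise pollute subgraph matching. A standalone sketch of the same sweep over an arbitrary GraphModule; the helper name is hypothetical:

    import torch
    import torch.fx

    def strip_dead_copy_nodes(gm: torch.fx.GraphModule) -> None:
        # Erase aten.copy_ calls that nothing consumes, then clean up,
        # mirroring the block added to _get_aten_graph_module_for_pattern.
        for node in list(gm.graph.nodes):
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.copy_.default
                and len(node.users) == 0
            ):
                gm.graph.erase_node(node)
        gm.graph.eliminate_dead_code()
        gm.recompile()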
@@ -370,8 +380,8 @@ def _replace_literals_with_new_placeholders(
         return x - 3

     example_inputs = (torch.randn(1, 3, 3, 3),)
-    pattern_gm = get_aten_graph_module(pattern, example_inputs)
-    replacement_gm = get_aten_graph_module(pattern, example_inptus)
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus)

     # 2. Before calling replace literals we'll see the following graph:
     def pattern(self, x):
@@ -456,8 +466,8 @@ def _replace_literals_with_existing_placeholders(
         -128,
         127,
     )
-    pattern_gm = get_aten_graph_module(pattern, example_inputs)
-    replacement_gm = get_aten_graph_module(pattern, example_inptus)
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus)

     # 2. Before calling replace literals we'll see the following graph:
     def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
@@ -12,7 +12,7 @@ from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.pt2e.utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -469,7 +469,7 @@ def _do_annotate_conv_bn(
     # Match against all conv dimensions and cuda variants
     for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
         pattern = get_pattern(conv_fn, relu_is_inplace)
-        pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
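
Note: SubgraphMatcherWithNameNodeMap (from torch.fx.passes.utils.matcher_with_name_node_map_utils) extends the plain subgraph matcher so that each match exposes pattern nodes by name, which _do_annotate_conv_bn then uses to locate the conv, bn, and relu nodes to annotate. A rough usage sketch; the toy pattern is hypothetical, and it must return an (output, name_node_map) pair for the named lookup to work:

    import torch
    from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
        SubgraphMatcherWithNameNodeMap,
    )
    from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern

    def toy_pattern(x):
        relu = torch.nn.functional.relu(x)
        return relu, {"relu": relu}

    example_inputs = (torch.randn(1, 3, 3, 3),)
    pattern_gm = _get_aten_graph_module_for_pattern(toy_pattern, example_inputs)
    matcher = SubgraphMatcherWithNameNodeMap(pattern_gm, ignore_literals=True)

    # For a traced GraphModule m, each match maps "relu" to the matched node:
    # for match in matcher.match(m.graph):
    #     relu_node = match.name_node_map["relu"]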