Update get_aten_graph_module (#121937)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121937
Approved by: https://github.com/andrewor14
Tugsbayasgalan Manlaibaatar
2024-03-14 16:55:18 -07:00
committed by PyTorch MergeBot
parent af86d67d61
commit 53d2188df9
7 changed files with 41 additions and 29 deletions
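For readers skimming the diff: get_aten_graph_module traces a callable into an ATen-level GraphModule that the PT2E quantization passes use as a match or replacement pattern; this PR renames it to the private _get_aten_graph_module_for_pattern and teaches it to drop dead copy_ nodes (see the utils.py hunk below). A minimal usage sketch, assuming the signature shown in that hunk; the dropout pattern function here is hypothetical:

    import torch
    import torch.nn.functional as F
    from torch.ao.quantization.pt2e.export_utils import _WrapperModule
    from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern

    # Hypothetical pattern callable, traced into an ATen-level graph
    def dropout_train(x):
        return F.dropout(x, p=0.5, training=True)

    example_inputs = (torch.randn(1),)
    pattern_gm = _get_aten_graph_module_for_pattern(
        _WrapperModule(dropout_train), example_inputs
    )
    print(pattern_gm.graph)  # aten-level nodes, ready for subgraph matching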

View File

@@ -309,7 +309,6 @@ coverage_ignore_functions = [
     "reference_representation_rewrite",
     # torch.ao.quantization.pt2e.utils
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
     "remove_tensor_overload_for_qdq_ops",
     # torch.ao.quantization.qconfig
     "get_default_qat_qconfig",

View File

@@ -1519,7 +1519,6 @@
         "SharedQuantizationSpec",
         "Tuple",
         "fold_bn_weights_into_conv_node",
-        "get_aten_graph_module",
         "replace_pattern_with_filters"
     ],
     "torch.ao.quantization.quantize_fx": [

View File

@@ -46,7 +46,7 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     See https://github.com/pytorch/pytorch/issues/103681.
     """
     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern

     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -62,17 +62,17 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
         example_inputs = (torch.randn(1),)
         if train_to_eval:
-            match_pattern = get_aten_graph_module(
+            match_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_train), example_inputs
             )
-            replacement_pattern = get_aten_graph_module(
+            replacement_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_eval), example_inputs
             )
         else:
-            match_pattern = get_aten_graph_module(
+            match_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_eval), example_inputs
             )
-            replacement_pattern = get_aten_graph_module(
+            replacement_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_train), example_inputs
             )
@@ -101,7 +101,7 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
     # Enable this support in future updates.

     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern

     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -137,13 +137,17 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
         torch.randn(1),  # bn_running_var
     )
     if train_to_eval:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_train), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_train), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_eval), example_inputs
         )
     else:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_eval), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_eval), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_train), example_inputs
         )
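
Once both graphs are built, the actual rewrite is a standard FX subgraph replacement. A rough sketch of the step that follows these hunks, shown here with the public torch.fx.subgraph_rewriter.replace_pattern as a stand-in (the file itself may use an internal variant):

    from torch.fx import subgraph_rewriter

    # Rewrite every occurrence of match_pattern in `m` (the GraphModule
    # being transformed above) with replacement_pattern, then recompile.
    subgraph_rewriter.replace_pattern(m, match_pattern, replacement_pattern)
    m.recompile()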

View File

@@ -24,7 +24,7 @@ from .utils import (
     _is_conv,
     _is_bn_node,
     fold_bn_weights_into_conv_node,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )

 if TYPE_CHECKING:
if TYPE_CHECKING:
@@ -546,7 +546,7 @@ def _fuse_conv_bn_qat_helper(
     m.graph.eliminate_dead_code()
     m.recompile()
     conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
-    match_pattern = get_aten_graph_module(conv_bn_pattern, example_inputs, is_cuda)
+    match_pattern = _get_aten_graph_module_for_pattern(conv_bn_pattern, example_inputs, is_cuda)

     # Step (1): Replace patterns with conv bias
     #
@@ -555,7 +555,7 @@
     # TODO: use the public replace_pattern API once it also returns replacement nodes
     qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
-    replacement_pattern_with_conv_bias = get_aten_graph_module(
+    replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern,
         example_inputs,
         is_cuda,
@@ -572,7 +572,7 @@
     # Step (2): Replace patterns without conv bias
     qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
-    replacement_pattern_no_conv_bias = get_aten_graph_module(
+    replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern_no_conv_bias,
         example_inputs,
         is_cuda,
@@ -738,11 +738,11 @@ def _fold_conv_bn_qat_helper(
     match_pattern = _get_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
+    match_pattern = _get_aten_graph_module_for_pattern(match_pattern, example_inputs, is_cuda, **kwargs)
     replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
+    replacement_pattern = _get_aten_graph_module_for_pattern(replacement_pattern, example_inputs, is_cuda, **kwargs)
     replacements.extend(
         replace_pattern_with_filters(
             m,
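
On the TODO above: replace_pattern_with_filters is used instead of the public replace_pattern because it reports which nodes each match was replaced with, which the QAT fusion needs in order to copy metadata onto the new conv/bn nodes. A hedged sketch of the call, with the signature assumed from torch.fx.subgraph_rewriter at this commit:

    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    replacements = replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],    # optional predicates to reject spurious matches
        ignore_literals=True,
    )
    for r in replacements:
        print(r.replacements)  # replacement nodes created for this match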

View File

@@ -2,7 +2,7 @@ import torch
 from torch.fx import GraphModule
 from ..export_utils import _WrapperModule
 from ..utils import (
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
     remove_tensor_overload_for_qdq_ops,
     _replace_literals_with_new_placeholders,
     _replace_literals_with_existing_placeholders,
@@ -586,9 +586,9 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         replacement = rewrite_info.replacement
         pattern_post_trans = rewrite_info.pattern_post_trans
         replacement_post_trans = rewrite_info.replacement_post_trans
-        pattern = get_aten_graph_module(pattern, example_inputs)  # type: ignore[arg-type, assignment]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
-        replacement = get_aten_graph_module(replacement, example_inputs)  # type: ignore[arg-type, assignment]
+        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
         if pattern_post_trans:
             pattern = pattern_post_trans(pattern)

View File

@@ -20,7 +20,7 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation

 __all__ = [
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
+    "_get_aten_graph_module_for_pattern",
     "remove_tensor_overload_for_qdq_ops",
 ]

@@ -292,7 +292,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
             node_name_to_scope[n.name] = current_scope
     return node_name_to_scope

-def get_aten_graph_module(
+def _get_aten_graph_module_for_pattern(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
     is_cuda: bool = False,
@@ -310,6 +310,16 @@ def get_aten_graph_module(
     )
     aten_pattern.graph.eliminate_dead_code()
     aten_pattern.recompile()
+
+    # ep.module() adds copy_ nodes for the mutated inputs.
+    # For patterns, it doesn't matter
+    for node in aten_pattern.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default and len(node.users) == 0:
+            aten_pattern.graph.erase_node(node)
+
+    aten_pattern.graph.eliminate_dead_code()
+    aten_pattern.recompile()
+
     return aten_pattern

 def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
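
For intuition on the new cleanup in the hunk above: when a traced pattern mutates one of its inputs or buffers (as batch norm does to its running stats), unlifting the exported program via ep.module() re-inserts aten.copy_ nodes that write the mutated values back. Those copies carry no meaning for pattern matching, so dead ones are erased. A minimal, hypothetical illustration of where the copy_ nodes come from:

    import torch

    class MutatesBuffer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("stat", torch.zeros(1))

        def forward(self, x):
            self.stat.add_(1)  # in-place mutation, like BN running stats
            return x + self.stat

    ep = torch.export.export(MutatesBuffer(), (torch.randn(2),))
    gm = ep.module()  # unlifting inserts aten.copy_ to persist the mutation
    print(gm.graph)   # the pattern helper above strips such dead copy_ nodes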
@@ -370,8 +380,8 @@ def _replace_literals_with_new_placeholders(
            return x - 3

        example_inputs = (torch.randn(1, 3, 3, 3),)
-       pattern_gm = get_aten_graph_module(pattern, example_inputs)
-       replacement_gm = get_aten_graph_module(pattern, example_inptus)
+       pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+       replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus)

    # 2. Before calling replace literals we'll see the following graph:
        def pattern(self, x):
@@ -456,8 +466,8 @@ def _replace_literals_with_existing_placeholders(
            -128,
            127,
        )
-       pattern_gm = get_aten_graph_module(pattern, example_inputs)
-       replacement_gm = get_aten_graph_module(pattern, example_inptus)
+       pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+       replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus)

    # 2. Before calling replace literals we'll see the following graph:
        def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):

View File

@@ -12,7 +12,7 @@ from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.pt2e.utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -469,7 +469,7 @@ def _do_annotate_conv_bn(
     # Match against all conv dimensions and cuda variants
     for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
         pattern = get_pattern(conv_fn, relu_is_inplace)
-        pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
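
Unlike the rewrite passes above, this quantizer path only matches: SubgraphMatcherWithNameNodeMap additionally records, per match, a map from human-readable names to matched nodes, which the annotation logic then consumes. A rough sketch of the consuming loop; the "conv"/"bn" keys and the GraphModule name `m` are assumptions for illustration:

    for match in matcher.match(m.graph):
        # Look up this occurrence's nodes by name rather than by position
        conv_node = match.name_node_map["conv"]
        bn_node = match.name_node_map["bn"]
        # ...attach QuantizationAnnotation to conv_node / bn_node here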