Files
pytorch/torch/ao/quantization/pt2e/utils.py
andrewor14 04c1e07fd7 [quant] Move dropout replacement to move_model_to_eval (#108184) (#108255)
Summary: This commit adds a public-facing
`torch.ao.quantization.move_model_to_eval` util function
for QAT users. Instead of calling model.eval() on an exported
model (which doesn't work, see
https://github.com/pytorch/pytorch/issues/103681), the user
would call this new util function instead. This ensures special
ops such as dropout and batchnorm (not supported yet) will have
the right behavior when the graph is later used for inference.

Note: Support for an equivalent `move_model_to_train` will be
added in the future. This is difficult to do for dropout
currently because the eval pattern of dropout is simply a clone
op, which we cannot just match and replace with a dropout op.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_move_model_to_eval

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel, supriyar

Differential Revision: [D48814735](https://our.internmc.facebook.com/intern/diff/D48814735)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108184
Approved by: https://github.com/jerryzh168
2023-09-05 13:42:37 -07:00

404 lines
16 KiB
Python

import torch
from torch.fx import (
GraphModule,
Node,
)
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights
import operator
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
from torch.utils._pytree import LeafSpec
# Names exported as this module's public API (e.g. via `from ... import *`).
__all__ = [
"fold_bn_weights_into_conv_node",
"get_aten_graph_module",
"move_model_to_eval",
"remove_tensor_overload_for_qdq_ops",
]
def _get_tensor_constant_from_node(node, m):
if node is None:
return None
assert node.op == "get_attr"
return getattr(m, node.target)
def _get_all_arguments(orig_args, orig_kwargs, args_schema):
all_args = []
for i, schema in enumerate(args_schema):
if schema.name in orig_kwargs:
all_args.append(orig_kwargs[schema.name])
elif not schema.kwarg_only and i < len(orig_args):
all_args.append(orig_args[i])
else:
all_args.append(schema.default_value)
return all_args
def fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Optional[Node],
    bn_node: Node,
    m: GraphModule
) -> None:
    """Fold a batch norm node's parameters into a conv node's weight and bias.

    The fused weight and bias are stored on ``m``, ``conv_node`` is rewired to
    read them, and every ``getitem(bn_node, 0)`` user is redirected to
    ``conv_node``. The batch norm node itself is not erased here.

    Args:
        conv_node: the convolution node whose weight/bias are replaced.
        conv_weight_node: ``get_attr`` node holding the conv weight.
        conv_bias_node: ``get_attr`` node holding the conv bias, or ``None``
            when the conv has no bias (a new bias attribute is then created).
        bn_node: the batch norm node whose statistics are folded in.
        m: the ``GraphModule`` owning these nodes; modified in place.

    Raises:
        ValueError: if ``bn_node.target`` is neither
            ``_native_batch_norm_legit_no_training.default`` nor
            ``_native_batch_norm_legit.default``.
    """
    # conv2d args: input, weight, bias, stride, padding, dilation, ...
    # Note: this should also work for conv1d, conv3d and transposed conv1-3d as well with
    # easy tweaks
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = not (conv_node.target == torch.ops.aten.conv2d.default)
    # TODO(Leslie): WA to support both graph capture of `torch._export.capture_pre_autograd_graph`
    # and `torch._dynamo_export` for 2.1 release, remove it after formal support of new graph capture
    # API in Inductor for X86.
    if conv_node.target == torch.ops.aten.convolution.default:
        assert type(conv_node.args[6]) is bool
        transpose = conv_node.args[6]

    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        eps_arg_index = 6
    elif bn_node.target == torch.ops.aten._native_batch_norm_legit.default:
        eps_arg_index = 7
    else:
        # Format the target into the message; passing it as a second
        # positional argument would render the message as a tuple.
        raise ValueError(f"BN node target is unexpected {bn_node.target}")
    bn_eps = bn_args[eps_arg_index]

    fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # TODO(Leslie): Remove the check of node target after formal support of new graph capture
    # API `torch._export.capture_pre_autograd_graph` in Inductor for X86.
    # filling in the default bias argument
    if len(conv_args) == 2 and (conv_node.target == torch.ops.aten.conv2d.default):
        conv_args.append(None)

    weight_attr_name = conv_weight_node.target
    assert isinstance(weight_attr_name, str)
    setattr(m, weight_attr_name, fused_weight)
    if conv_bias_node is not None:
        bias_attr_name = conv_bias_node.target
        setattr(m, bias_attr_name, fused_bias)  # type: ignore[arg-type]
    else:
        bias_attr_name = weight_attr_name + "_bias"
        # Store the fused bias on the module *before* creating the get_attr
        # node: Graph.get_attr warns when the owning module has no matching
        # attribute at the time the node is inserted.
        setattr(m, bias_attr_name, fused_bias)  # type: ignore[arg-type]
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
        # NOTE: here we assume the bias of conv is not quantized!
        conv_args[2] = get_bias_node
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    #
    # Before:
    # conv -> bn - (first output) -> users1
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # After:
    # conv -> (first output) -> users1
    #       bn -
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # if users2 and users3 are empty then bn will be removed through dead code elimination
    for user in bn_node.users:
        if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
            continue
        user.replace_all_uses_with(conv_node)
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    """Fold eval-mode batch norms into the conv nodes that feed them, in place.

    Every ``_native_batch_norm_legit_no_training.default`` call whose first
    argument is a ``conv2d.default`` or ``convolution.default`` call is folded
    via ``fold_bn_weights_into_conv_node``. Afterwards dead code is eliminated
    and the module is recompiled.

    NOTE(review): other users of the conv output are not checked here, so any
    node besides the batch norm that reads the conv output would also see the
    folded weights -- confirm callers only pass graphs where that cannot occur.
    """
    bn_target = torch.ops.aten._native_batch_norm_legit_no_training.default
    # TODO(Leslie): Remove the check of node target torch.ops.aten.convolution.default after formal
    # support of new graph capture API `torch._export.capture_pre_autograd_graph` in Inductor for X86.
    conv_targets = (
        torch.ops.aten.conv2d.default,
        torch.ops.aten.convolution.default,
    )

    for node in m.graph.nodes:
        if not (node.op == "call_function" and node.target == bn_target):
            continue
        bn_node = node
        producer = bn_node.args[0]
        if not (producer.op == "call_function" and producer.target in conv_targets):
            continue
        conv_node = producer
        weight_node = conv_node.args[1]
        # The bias is optional and may be absent from the positional args.
        bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        fold_bn_weights_into_conv_node(conv_node, weight_node, bias_node, bn_node, m)

    m.graph.eliminate_dead_code()
    m.recompile()
def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
# TODO: move this information to fx node itself
node_name_to_scope: Dict[str, Tuple[str, type]] = {}
for n in model.graph.nodes:
nn_module_stack = n.meta.get("nn_module_stack", None)
current_scope = ("", type(None))
if nn_module_stack:
bt = list(nn_module_stack.values())[-1]
current_scope = (bt[0].split(".")[-1], bt[1])
node_name_to_scope[n.name] = current_scope
return node_name_to_scope
def get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    **kwargs,
) -> GraphModule:
    """
    Capture `pattern` as an FX graph with decomposed aten ops.

    `pattern` is traced with `capture_pre_autograd_graph` using
    `example_inputs` as positional inputs and `kwargs` as keyword inputs.
    Dead code is removed from the captured graph and the module is
    recompiled before it is returned.
    """
    # Imported here rather than at module level to avoid circular dependencies
    from torch._export import capture_pre_autograd_graph

    captured = capture_pre_autograd_graph(pattern, example_inputs, kwargs)
    captured.graph.eliminate_dead_code()
    captured.recompile()
    return captured
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
""" Remove .tensor overload for quantize/dequantize ops so that we can
use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
"""
_MAP = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
}
for n in match_pattern.graph.nodes:
if n.op != "call_function":
continue
if n.target in _MAP:
n.target = _MAP[n.target]
def _replace_dropout_for_eval(m: GraphModule):
    """
    Swap the aten training-mode dropout pattern in `m` for its eval-mode form.

    For models built from dropout torch ops (nn.Dropout, F.dropout), calling
    model.eval() effectively turns those ops into noops. This does not happen
    automatically for exported models, because the aten dropout patterns that
    were generated for training stay in the graph. This function rewrites those
    patterns so that no further dropout is applied during eval, then recompiles `m`.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    def dropout_train(x):
        return F.dropout(x, p=0.5, training=True)

    def dropout_eval(x):
        return F.dropout(x, p=0.5, training=False)

    sample_inputs = (torch.randn(1),)
    # Capture the training pattern first, then its eval counterpart.
    train_graph, eval_graph = [
        get_aten_graph_module(fn, sample_inputs)
        for fn in (dropout_train, dropout_eval)
    ]

    replace_pattern_with_filters(
        m,
        train_graph,
        eval_graph,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()
def _is_literal(arg):
if isinstance(arg, (int, float)):
return True
if isinstance(arg, (tuple, list)):
return all(map(_is_literal, arg))
return False
def _replace_literals_with_new_placeholders(
    gm: torch.fx.GraphModule,
    merge_dup: bool = False,
    exclude_literals: Optional[List[Any]] = None
):
    """Replace the literal args in the graph with newly created placeholder nodes.

    The placeholders are created on the fly while the graph is traversed, so
    that literal arguments in the graph can be matched and replaced. For every
    placeholder added, a `LeafSpec` is appended to `gm._in_spec` as well.

    To use this, the pattern and replacement graph should have the exact same
    number of literal args, and they should be used in the exact same order in
    both graphs. If the literal arguments are not used in the same order, use
    `_replace_literals_with_existing_placeholders` instead.

    Args:
        `gm`: input GraphModule to transform (modified in place and returned)
        `merge_dup`: whether a literal that appears multiple times in the graph
            should correspond to a single shared placeholder
        `exclude_literals`: a list of literals that will not be replaced with
            placeholders

    Example:

    # 1. Original graphs
    def pattern(self, x):
        return x + 3

    def replacement(self, x):
        return x - 3

    example_inputs = (torch.randn(1, 3, 3, 3),)
    pattern_gm = get_aten_graph_module(pattern, example_inputs)
    replacement_gm = get_aten_graph_module(replacement, example_inputs)

    # 2. Replace the literals (before this call the graphs still contain `3`)
    pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
    replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)

    # 3. After replacing literals with new placeholder nodes
    def pattern(self, x, new_ph):
        return x + new_ph

    def replacement(self, x, new_ph):
        return x - new_ph
    """
    anchor_ph = None
    ph_count = 0
    ph_for_literal: Dict[Union[float, bool, int, torch.dtype], Node] = {}
    excluded = [] if exclude_literals is None else exclude_literals

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            # Remember the most recent placeholder: new ones go right after it.
            anchor_ph = node
            ph_count += 1
            continue

        rewritten_args = []
        with gm.graph.inserting_after(anchor_ph):
            for arg in node.args:
                if not _is_literal(arg) or arg in excluded:
                    rewritten_args.append(arg)
                    continue
                if merge_dup and arg in ph_for_literal:
                    rewritten_args.append(ph_for_literal[arg])
                    continue
                new_ph = gm.graph.placeholder("arg" + str(ph_count))
                rewritten_args.append(new_ph)
                # Keep the module's input spec in step with the new input.
                gm._in_spec.children_specs[0].children_specs.append(LeafSpec())
                ph_count += 1
                if merge_dup:
                    ph_for_literal[arg] = new_ph

        node.args = tuple(rewritten_args)
    return gm
def _replace_literals_with_existing_placeholders(
    gm: torch.fx.GraphModule,
    exclude_literals: Optional[List[Any]] = None,
    literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None
):
    """Replace the literal args in the graph with **existing** placeholder nodes.

    This lets literal arguments in the graph be matched and replaced. Only the
    args of `call_function` nodes are rewritten; `gm` is modified in place and
    returned.

    To use this, all literal args in the graph should be unique and each of
    them should correspond to exactly one placeholder node.

    Args:
        `gm`: input GraphModule to transform
        `exclude_literals`: a list of literals that will not be replaced
        `literal_to_ph_idx`: maps a literal to the index (among the graph's
            placeholder nodes, in graph order) of the placeholder replacing it

    Example:

    # 1. Original graphs
    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)

    example_inputs = (
        torch.randn(1, 3, 3, 3),
        1.0,
        0,
        -128,
        127,
    )
    pattern_gm = get_aten_graph_module(pattern, example_inputs)
    replacement_gm = get_aten_graph_module(replacement, example_inputs)

    # 2. Before replacing literals, the captured graphs look like this:
    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
        x_i8 = torch.clamp(x_i8, -128, 127)
        return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)

    # Note that literal args appear in different order in pattern and replacement graph, so
    # we can't use _replace_literals_with_new_placeholders
    literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
    pattern_gm = _replace_literals_with_existing_placeholders(
        pattern_gm, literal_to_ph_idx=literal_to_ph_idx)
    replacement_gm = _replace_literals_with_existing_placeholders(
        replacement_gm, literal_to_ph_idx=literal_to_ph_idx)

    # 3. After replacing literals with existing placeholder nodes
    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
    """
    excluded = [] if exclude_literals is None else exclude_literals
    ph_index_for = {} if literal_to_ph_idx is None else literal_to_ph_idx
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]

    def _substitute(arg):
        # Only non-excluded literals that have a mapped placeholder are swapped.
        if _is_literal(arg) and arg not in excluded and arg in ph_index_for:
            return placeholders[ph_index_for[arg]]
        return arg

    for node in gm.graph.nodes:
        if node.op == "call_function":
            node.args = tuple(_substitute(arg) for arg in node.args)
    return gm
# TODO: also support move_model_to_train
# TODO: also support standalone batchnorm
def move_model_to_eval(m: GraphModule):
    """
    Put an exported GraphModule into eval mode and return it.

    This plays the role of model.eval(), but only for certain special ops:
    currently the dropout patterns in `m` are rewritten, in place, to their
    eval form.

    QAT users should call this before running inference on the model.
    """
    _replace_dropout_for_eval(m)
    return m