[reland][quant][pt2e] Change input act annotation to a map and allow dynamic quantization for non zeroth argument (#101005) (#101041)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101005

Previously, the node annotation looked like the following:
```
node.meta["..."] = {
    "input_act_obs_or_fq_ctr": ...,
    "weight_obs_or_fq_ctr": ...,
    "weight_index": 1,
}
```
That is, we had to specify the index of the weight argument and also keep a separate key for the weight config. In this PR we change that to:
```
node.meta["..."] = {
    "input_act_obs_or_fq_ctr_map": {input_node: ..., weight_node: ...},
}
```
This supports specifying the observer/fake quant constructor for any argument of the node, which also allows dynamic quantization to be configured for non-zeroth arguments.
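
As an illustration only (not code from this PR), here is a minimal sketch of the new map-style annotation. The toy module, the observer choices, and the use of `torch.fx.symbolic_trace` instead of the pt2e capture flow are assumptions made for brevity; the key names (`target_dtype_info`, `input_act_obs_or_fq_ctr_map`, `_annotated`) come from this change:
```
# Illustrative sketch only: in the real pt2e flow a Quantizer (e.g. QNNPackQuantizer
# below) writes these annotations on the captured aten-level graph.
import torch
import torch.fx
import torch.nn.functional as F
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class ToyLinear(torch.nn.Module):  # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return F.linear(x, self.weight)

gm = torch.fx.symbolic_trace(ToyLinear())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == F.linear:
        input_node, weight_node = node.args[0], node.args[1]
        # One map from each input argument node to its observer/fake quant
        # constructor, replacing the old per-key entries plus "weight_index".
        node.meta["target_dtype_info"] = {
            "input_act_obs_or_fq_ctr_map": {
                input_node: MinMaxObserver.with_args(dtype=torch.quint8),
                weight_node: PerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            },
            "_annotated": True,
        }
```
Because the map is keyed by the argument node itself, the same mechanism covers weights, biases, and any other non-zeroth argument without index bookkeeping.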

Test Plan: buck2 test @//mode/opt //caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_resnet18_with_quantizer_api (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2EModels)'

Differential Revision: D45719781

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101041
Approved by: https://github.com/andrewor14
Jerry Zhang
2023-05-10 17:43:21 +00:00
committed by PyTorch MergeBot
parent 3941bbc5ba
commit 058d740f59
3 changed files with 220 additions and 133 deletions

View File

@ -1,7 +1,11 @@
 import torch
 from torch._subclasses import FakeTensor
 from torch.ao.quantization.fx.prepare import (
-    _maybe_insert_input_observers_for_node,
+    _needs_obs_or_fq,
+    _get_arg_as_input_act_obs_or_fq_ctr,
+    _get_output_act_obs_or_fq_ctr,
+    _get_dtype_and_is_dynamic,
+    _insert_observer,
     _maybe_insert_output_observer_for_node,
     _is_observer_in_same_graph,
     _maybe_make_input_output_share_observers,
@ -9,12 +13,134 @@ from torch.ao.quantization.fx.prepare import (
     _maybe_insert_observers_before_graph_output,
     _save_state
 )
-from torch.fx import GraphModule
-from torch.fx import Node
+from torch.fx import (
+    GraphModule,
+    Node,
+)
 from torch.fx.node import Argument
 from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Union, Any
+
+def _maybe_insert_input_observer_for_arg_or_kwarg(
+    node: Union[Node, Any],
+    arg: Argument,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+) -> Argument:
+    """
+    Given a `node` and an `arg`, inserts an input observer between
+    `node` and `arg` if necessary.
+    """
+    # for ops such as torch.cat([x0, x1]),
+    # traverse through the list
+    if isinstance(arg, (list, tuple)):
+        new_arg_to_return = []
+        for inner_arg in arg:
+            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+                node, inner_arg, qconfig, model, named_modules,
+            )
+            new_arg_to_return.append(new_inner_arg)
+        return type(arg)(new_arg_to_return)
+
+    if not isinstance(arg, Node):
+        return arg
+    assert isinstance(arg, Node)
+    # default (no observer)
+    new_arg = arg
+
+    # TODO: we are assuming "target_dtype_info" exists here, maybe
+    # a default value also need to be provided here
+    target_dtype_info = node.meta["target_dtype_info"]
+    # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
+    # we'll default to False, this makes configuring this field optional for users
+    reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
+    arg_as_input_act_obs_or_fq_ctr = _get_arg_as_input_act_obs_or_fq_ctr(arg, node, named_modules)
+    act_post_process_ctr = arg_as_input_act_obs_or_fq_ctr
+    arg_as_output_act_obs_or_fq_ctr = _get_output_act_obs_or_fq_ctr(arg, named_modules)
+    arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq_ctr)
+    arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq_ctr)
+    needs_obs_or_fq = _needs_obs_or_fq(
+        arg_as_output_target_dtype,
+        arg_as_output_target_is_dynamic,
+        arg_as_input_target_dtype,
+        arg_as_input_target_is_dynamic,
+        reuse_input_obs_or_fq,
+        True,  # set is_zeroth_arg to True so that non-zeroth arg can be observed for
+               # dynamic quantization as well
+    )
+
+    if needs_obs_or_fq:
+        new_obs_mod = act_post_process_ctr()
+        existing_obs_node = None
+        # Before using the new observer, check if an observer
+        # of the correct type already exists. If it does, use it.
+        # This prevents duplicate observer insertions if a node is
+        # used by multiple nodes.
+        # TODO: this is looking into how the value is used in the future
+        # we should remove this
+        # removing this means we insert one observer for each use, even if they
+        # have the same dtype, we can have an extra pass that removes the extra observers
+        for maybe_obs_node, _ in arg.users.items():
+            if maybe_obs_node.op == 'call_module':
+                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
+                if (
+                    type(maybe_obs_mod) == type(new_obs_mod) and
+                    maybe_obs_mod.dtype == arg_as_input_target_dtype
+                ):
+                    existing_obs_node = maybe_obs_node
+                    break
+
+        if existing_obs_node is None:
+            new_obs_node = _insert_observer(
+                arg, new_obs_mod, model, named_modules, model.graph)  # type: ignore[arg-type]
+            # override this arg to be the observed arg
+            new_arg = new_obs_node
+        else:
+            new_arg = existing_obs_node
+
+    return new_arg
+
+def _maybe_insert_input_observers_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+) -> None:
+    """
+    If needed, inserts observers to the input args and kwargs of `node`.
+    Note: modifies `node` inplace.
+
+    For example, if cur_node needs an observer after prev_node, we change from
+
+      prev_node -> cur_node
+
+    To
+
+      prev_node -> obs -> cur_node
+    """
+    # Look through every input arg.  If that arg's target dtype does not
+    # match the current node's target dtype, insert an observer.
+    new_args = []
+    for arg in node.args:
+        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node, arg, qconfig, model, named_modules,
+        )
+        new_args.append(new_arg)
+
+    assert len(node.kwargs) == 0, " expecting kwargs for aten op IR to be empty"
+
+    # assign the new args to the node, inplace
+    node.args = tuple(new_args)
+
 def _maybe_insert_input_and_output_observers_for_node(
     node: Node,
@ -43,10 +169,6 @@ def _maybe_insert_input_and_output_observers_for_node(
         None, # qconfig
         model,
         named_modules,
-        model.graph,
-        None, # qhandler
-        PrepareCustomConfig(),
-        None, # backend_config
     )
     # this returns the new observer node if it was needed

View File

@ -406,14 +406,15 @@ class QNNPackQuantizer(Quantizer):
         if _is_annotated([relu_node, conv_node]):
             return
+        input_node = conv_node.args[0]
+        weight_node = conv_node.args[1]
+        bias_node = conv_node.args[2]
         conv_node.meta["target_dtype_info"] = {
-            "input_act_obs_or_fq_ctr": get_act_obs_or_fq_ctr(quantization_config),
-            "weight_obs_or_fq_ctr": get_weight_obs_or_fq_ctr(quantization_config),
-            "bias_obs_or_fq_ctr": get_bias_obs_or_fq_ctr(quantization_config),
-            # TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
-            "weight_index": 1,
-            # TODO: validation of bias_index must be set if bias_obs_or_fq_ctr is set
-            "bias_index": 2,
+            "input_act_obs_or_fq_ctr_map": {
+                input_node: get_act_obs_or_fq_ctr(quantization_config),
+                weight_node: get_weight_obs_or_fq_ctr(quantization_config),
+                bias_node: get_bias_obs_or_fq_ctr(quantization_config),
+            },
             "_annotated": True,
         }
         relu_node.meta["target_dtype_info"] = {
@ -433,15 +434,17 @@ class QNNPackQuantizer(Quantizer):
         # skip annotation if it is already annotated
         if _is_annotated([conv_node]):
             return
+        input_node = conv_node.args[0]
+        weight_node = conv_node.args[1]
+        bias_node = conv_node.args[2]
         conv_node.meta["target_dtype_info"] = {
-            "input_act_obs_or_fq_ctr": get_act_obs_or_fq_ctr(quantization_config),
-            "weight_obs_or_fq_ctr": get_weight_obs_or_fq_ctr(quantization_config),
-            "bias_obs_or_fq_ctr": get_bias_obs_or_fq_ctr(quantization_config),
+            "input_act_obs_or_fq_ctr_map": {
+                input_node: get_act_obs_or_fq_ctr(quantization_config),
+                weight_node: get_weight_obs_or_fq_ctr(quantization_config),
+                bias_node: get_bias_obs_or_fq_ctr(quantization_config),
+            },
             "output_act_obs_or_fq_ctr": get_act_obs_or_fq_ctr(quantization_config),
-            # TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
-            "weight_index": 1,
-            # TODO: validation of bias_index must be set if bias_obs_or_fq_ctr is set
-            "bias_index": 2,
             "_annotated": True,
         }

View File

@ -144,10 +144,9 @@ def _needs_obs_or_fq(
       is_zeroth_arg: we only dynamically quantize the first arg of the node right now
           this should be removed when we enable configuring dynamic quantization
-              for a specific argument
+              for a specific argument, this can be removed if we deprecate fx graph mode
+              quantization
 
-    Note: we want to refactor the annotation specification api soon to use
-    a map from user_node to obs_or_fq_ctr
     """
     # need to insert placeholder observer for dynamic quantization so that it can
@ -512,120 +511,83 @@ def _get_target_activation_dtype_for_node(
         }
     return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
 
+def _get_output_act_obs_or_fq_ctr(
+    arg: Node,
+    named_modules: Dict[str, torch.nn.Module],
+) -> Any:
+    """ Get the constructor for observer or fake quant object for
+    the argument in the original graph as the output of previous node,
+    skipping inserted observers
+    We are assuming that the observers are inserted correctly, and the dtype for
+    argument in quantized graph will match what is specified by the qconfig
+    """
+    assert isinstance(arg, Node)
+    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
+    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
+    # Since we modified the graph in this case, we must trace back from the args through
+    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
+    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
+    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
+    output_act_obs_or_fq_ctr = None
+    if custom_module_lstm_node is not None:
+        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+    elif _is_activation_post_process_node(arg, named_modules):
+        observed_arg = arg.args[0]
+        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
+        output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+    else:
+        if "target_dtype_info" in arg.meta:
+            output_act_obs_or_fq_ctr = \
+                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+        else:
+            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
+    return output_act_obs_or_fq_ctr
+
 def _get_arg_target_dtype_as_output(
     arg: Node,
     named_modules: Dict[str, torch.nn.Module],
-) -> Optional[Union[torch.dtype, type]]:
-    """ Get the target output activation dtype for
-    the argument in the original graph, skipping inserted observers
-    We are assuming that the observers are inserted correctly, and the dtype for
-    argument in quantized graph will match what is specified by the qconfig
-    """
-    assert isinstance(arg, Node)
-    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
-    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
-    # Since we modified the graph in this case, we must trace back from the args through
-    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
-    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
-    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
-    output_act_obs_or_fq_ctr = None
-    if custom_module_lstm_node is not None:
-        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    elif _is_activation_post_process_node(arg, named_modules):
-        observed_arg = arg.args[0]
-        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
-        output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    else:
-        if "target_dtype_info" in arg.meta:
-            output_act_obs_or_fq_ctr = \
-                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        else:
-            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
-    output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
-    # TODO: should support is_dynamic here as well
-    return output_act_dtype
+) -> Optional[torch.dtype]:
+    arg_as_output_act_obs_or_fq_ctr = _get_output_act_obs_or_fq_ctr(arg, named_modules)
+    arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq_ctr)
+    return arg_as_output_target_dtype
-# TODO: merge with _get_arg_target_dtype_as_output
-def _get_arg_target_is_dynamic_as_output(
-    arg: Node,
-    named_modules: Dict[str, torch.nn.Module],
-) -> bool:
-    """ Get the target output activation dtype for
-    the argument in the original graph, skipping inserted observers
-    We are assuming that the observers are inserted correctly, and the dtype for
-    argument in quantized graph will match what is specified by the qconfig
-    """
-    assert isinstance(arg, Node)
-    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
-    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
-    # Since we modified the graph in this case, we must trace back from the args through
-    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
-    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
-    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
-    output_act_obs_or_fq_ctr = None
-    if custom_module_lstm_node is not None:
-        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    elif _is_activation_post_process_node(arg, named_modules):
-        observed_arg = arg.args[0]
-        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
-        output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    else:
-        if "target_dtype_info" in arg.meta:
-            output_act_obs_or_fq_ctr = \
-                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        else:
-            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
-    _, output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
-    # return output_is_dynamic
-    return False
-def _get_arg_target_dtype_as_input_to_node(
+def _get_arg_as_input_act_obs_or_fq_ctr(
     arg: Node,
     node: Node,
     named_modules: Dict[str, torch.nn.Module],
-) -> Optional[Union[torch.dtype, type]]:
-    """ Get the target argument dtype for the argument `arg`, as input
-    to node `node`
+) -> Any:
+    """ Get the observer or fake quant constructor for the Argument `arg`, as input
+    to Node `node`
     """
     assert isinstance(arg, Node)
+    # "input_act_obs_or_fq_ctr_map" is the more general design we'll use for pt2e path
+    # it is a map from input argument node to observer or fake quant constructor, for example
+    # for the following graph:
+    # x -> conv -> output
+    #
+    # we may annotate conv node like the following:
+    # conv.meta[...] = {"input_act_obs_or_fq_ctr_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...}
+    #
+    if "target_dtype_info" in node.meta and "input_act_obs_or_fq_ctr_map" in node.meta["target_dtype_info"]:
+        input_act_obs_or_fq_ctr = \
+            node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr_map"].get(arg, _DEFAULT_FP32_OBS_OR_FQ_CTR)
+        return input_act_obs_or_fq_ctr
+
+    # we can remove the following path in the future if fx graph mode quantization is
+    # no longer used
     is_weight = node_arg_is_weight(node, arg)
     is_bias = node_arg_is_bias(node, arg)
     is_activation = not is_weight and not is_bias
+    obs_or_fq_ctr = None
     if is_activation:
-        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        qconfig_dtype, _ = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
-        return qconfig_dtype
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
     elif is_weight:
-        if node.target in NON_QUANTIZABLE_WEIGHT_OPS:
-            return None
-        else:
-            weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-            qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq_ctr)
-            return qconfig_weight_dtype
+        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
+            obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
     else:
-        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq_ctr)
-        return qconfig_bias_dtype
-
-def _get_arg_target_is_dynamic_as_input_to_node(
-    arg: Node,
-    node: Node,
-    named_modules: Dict[str, torch.nn.Module],
-) -> bool:
-    """ Get the target argument dtype for the argument `arg`, as input
-    to node `node`
-    """
-    assert isinstance(arg, Node)
-    is_weight = node_arg_is_weight(node, arg)
-    is_bias = node_arg_is_bias(node, arg)
-    is_activation = not is_weight and not is_bias
-    if is_activation and "input_act_obs_or_fq_ctr" in node.meta["target_dtype_info"]:
-        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        _, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
-        return qconfig_is_dynamic
-    else:
-        return False
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+    return obs_or_fq_ctr
 
 def _maybe_insert_input_observer_for_arg_or_kwarg(
     node: Union[Node, Any],
@ -668,7 +630,6 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     # Note: qconfig can be None in this branch this we are getting act/fq from
     # node.meta now
     # regular flow for most nodes, except standalone modules
-    is_weight = node_arg_is_weight(node, arg)
     # TODO: we are assuming "target_dtype_info" exists here, maybe
     # a default value also need to be provided here
@ -676,16 +637,14 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
     # we'll default to False, this makes configuring this field optional for users
     reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
+    arg_as_input_act_obs_or_fq_ctr = _get_arg_as_input_act_obs_or_fq_ctr(arg, node, named_modules)
+    act_post_process_ctr = arg_as_input_act_obs_or_fq_ctr
+    arg_as_output_act_obs_or_fq_ctr = _get_output_act_obs_or_fq_ctr(arg, named_modules)
+    arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq_ctr)
+    arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq_ctr)
-    # TODO: check bias as well?
-    obs_or_fq_ctr_key = "weight_obs_or_fq_ctr" if is_weight else "input_act_obs_or_fq_ctr"
-    act_post_process_ctr = target_dtype_info.get(obs_or_fq_ctr_key, _DEFAULT_FP32_OBS_OR_FQ_CTR)
-    arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules)
-    arg_as_output_target_is_dynamic = _get_arg_target_is_dynamic_as_output(arg, named_modules)
-    arg_as_input_target_dtype = _get_arg_target_dtype_as_input_to_node(arg, node, named_modules)
-    arg_as_input_target_is_dynamic = \
-        _get_arg_target_is_dynamic_as_input_to_node(arg, node, named_modules) # type: ignore[arg-type]
     needs_obs_or_fq = _needs_obs_or_fq(
         arg_as_output_target_dtype,
         arg_as_output_target_is_dynamic,
@ -860,6 +819,9 @@ def _maybe_insert_output_observer_for_node(
     and returns it.
     If `node` does not need an output observer, returns None.
+
+    Note: inserting dynamic quantization ops for output is not supported in fx graph mode
+    quantization code path right now
     """
     assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'