[reland][quant][pt2e] Change input act annotation to a map and allow dynamic quantization for non zeroth argument (#101005) (#101041)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/101005

Previously the node annotation looked like the following:

```
node.meta["..."] = {
    "input_act_obs_or_fq_ctr": ...,
    "weight_obs_or_fq_ctr": ...,
    "weight_index": 1,
}
```

That is, we had to specify the positional index for the weight and keep a separate key for the weight config. In this PR we change that to:

```
node.meta["..."] = {
    "input_act_obs_or_fq_ctr_map": {input_node: ..., weight_node: ...},
}
```

This supports specifying the observer/fake quant constructor for any argument of the node.

Test Plan: buck2 test @//mode/opt //caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_resnet18_with_quantizer_api (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2EModels)'

Differential Revision: D45719781

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101041
Approved by: https://github.com/andrewor14
committed by PyTorch MergeBot
parent 3941bbc5ba
commit 058d740f59
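As a sketch of what the new map-based annotation looks like in practice (the module, traced graph, and observer choices below are illustrative, not the quantizer's actual defaults):

```
# A minimal sketch of the map-based annotation introduced by this PR.
# Observer choices here are illustrative only.
import torch
import torch.fx
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class M(torch.nn.Module):
    def forward(self, x, w):
        return torch.nn.functional.conv2d(x, w)

gm = torch.fx.symbolic_trace(M())
conv_node = next(n for n in gm.graph.nodes if n.op == "call_function")
input_node, weight_node = conv_node.args[0], conv_node.args[1]

# Old style (removed by this PR): separate weight key plus a positional index.
# conv_node.meta["target_dtype_info"] = {
#     "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8),
#     "weight_obs_or_fq_ctr": PerChannelMinMaxObserver.with_args(dtype=torch.qint8),
#     "weight_index": 1,
# }

# New style: one map from each input node to its observer/fake-quant constructor.
conv_node.meta["target_dtype_info"] = {
    "input_act_obs_or_fq_ctr_map": {
        input_node: MinMaxObserver.with_args(dtype=torch.quint8),
        weight_node: PerChannelMinMaxObserver.with_args(dtype=torch.qint8),
    },
    "_annotated": True,
}
```

The key point is that the map is keyed by the argument Node itself, so any argument, not just the zeroth, can be given its own observer or fake quant constructor.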
@@ -1,7 +1,11 @@
 import torch
 from torch._subclasses import FakeTensor
 from torch.ao.quantization.fx.prepare import (
-    _maybe_insert_input_observers_for_node,
+    _needs_obs_or_fq,
+    _get_arg_as_input_act_obs_or_fq_ctr,
+    _get_output_act_obs_or_fq_ctr,
+    _get_dtype_and_is_dynamic,
+    _insert_observer,
     _maybe_insert_output_observer_for_node,
     _is_observer_in_same_graph,
     _maybe_make_input_output_share_observers,
@@ -9,12 +13,134 @@ from torch.ao.quantization.fx.prepare import (
     _maybe_insert_observers_before_graph_output,
     _save_state
 )
-from torch.fx import GraphModule
-from torch.fx import Node
+from torch.fx import (
+    GraphModule,
+    Node,
+)
+from torch.fx.node import Argument

 from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Union, Any
+
+def _maybe_insert_input_observer_for_arg_or_kwarg(
+    node: Union[Node, Any],
+    arg: Argument,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+) -> Argument:
+    """
+    Given a `node` and an `arg`, inserts an input observer between
+    `node` and `arg` if necessary.
+    """
+    # for ops such as torch.cat([x0, x1]),
+    # traverse through the list
+    if isinstance(arg, (list, tuple)):
+        new_arg_to_return = []
+        for inner_arg in arg:
+            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+                node, inner_arg, qconfig, model, named_modules,
+            )
+            new_arg_to_return.append(new_inner_arg)
+        return type(arg)(new_arg_to_return)
+
+    if not isinstance(arg, Node):
+        return arg
+    assert isinstance(arg, Node)
+    # default (no observer)
+    new_arg = arg
+
+    # TODO: we are assuming "target_dtype_info" exists here, maybe
+    # a default value also need to be provided here
+    target_dtype_info = node.meta["target_dtype_info"]
+    # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
+    # we'll default to False, this makes configuring this field optional for users
+    reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
+    arg_as_input_act_obs_or_fq_ctr = _get_arg_as_input_act_obs_or_fq_ctr(arg, node, named_modules)
+    act_post_process_ctr = arg_as_input_act_obs_or_fq_ctr
+
+    arg_as_output_act_obs_or_fq_ctr = _get_output_act_obs_or_fq_ctr(arg, named_modules)
+    arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq_ctr)
+    arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq_ctr)
+
+
+    needs_obs_or_fq = _needs_obs_or_fq(
+        arg_as_output_target_dtype,
+        arg_as_output_target_is_dynamic,
+        arg_as_input_target_dtype,
+        arg_as_input_target_is_dynamic,
+        reuse_input_obs_or_fq,
+        True, # set is_zeroth_arg to True so that non-zeroth arg can be observed for
+              # dynamic quantization as well
+    )
+
+    if needs_obs_or_fq:
+
+        new_obs_mod = act_post_process_ctr()
+        existing_obs_node = None
+
+        # Before using the new observer, check if an observer
+        # of the correct type already exists. If it does, use it.
+        # This prevents duplicate observer insertions if a node is
+        # used by multiple nodes.
+        # TODO: this is looking into how the value is used in the future
+        # we should remove this
+        # removing this means we insert one observer for each use, even if they
+        # have the same dtype, we can have an extra pass that removes the extra observers
+        for maybe_obs_node, _ in arg.users.items():
+            if maybe_obs_node.op == 'call_module':
+                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
+                if (
+                    type(maybe_obs_mod) == type(new_obs_mod) and
+                    maybe_obs_mod.dtype == arg_as_input_target_dtype
+                ):
+                    existing_obs_node = maybe_obs_node
+                    break
+
+        if existing_obs_node is None:
+            new_obs_node = _insert_observer(
+                arg, new_obs_mod, model, named_modules, model.graph)  # type: ignore[arg-type]
+            # override this arg to be the observed arg
+            new_arg = new_obs_node
+        else:
+            new_arg = existing_obs_node
+
+    return new_arg
+
+def _maybe_insert_input_observers_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+) -> None:
+    """
+    If needed, inserts observers to the input args and kwargs of `node`.
+    Note: modifies `node` inplace.
+
+    For example, if cur_node needs an observer after prev_node, we change from
+
+      prev_node -> cur_node
+
+    To
+
+      prev_node -> obs -> cur_node
+
+    """
+    # Look through every input arg. If that arg's target dtype does not
+    # match the current node's target dtype, insert an observer.
+    new_args = []
+    for arg in node.args:
+        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node, arg, qconfig, model, named_modules,
+        )
+        new_args.append(new_arg)
+
+    assert len(node.kwargs) == 0, " expecting kwargs for aten op IR to be empty"
+
+    # assign the new args to the node, inplace
+    node.args = tuple(new_args)
+
 def _maybe_insert_input_and_output_observers_for_node(
     node: Node,
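The loop over `arg.users` above is the deduplication step: before inserting a fresh observer, it reuses an existing one of the same type and dtype. A standalone sketch of that predicate (function name and signature are illustrative, not part of the PR):

```
# Standalone sketch of the observer-reuse check performed above: before
# inserting a new observer module for `arg`, look through arg.users for an
# existing call_module user whose module has the same type and dtype.
from typing import Dict, Optional

import torch
from torch.fx import Node

def _find_matching_observer(
    arg: Node,
    new_obs_mod: torch.nn.Module,
    target_dtype,
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[Node]:
    for maybe_obs_node in arg.users:
        if maybe_obs_node.op != "call_module":
            continue
        maybe_obs_mod = named_modules[str(maybe_obs_node.target)]
        if (
            type(maybe_obs_mod) is type(new_obs_mod)
            and getattr(maybe_obs_mod, "dtype", None) == target_dtype
        ):
            # reuse: arg is already observed the same way by another consumer
            return maybe_obs_node
    return None
```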
@@ -43,10 +169,6 @@ def _maybe_insert_input_and_output_observers_for_node(
         None, # qconfig
         model,
         named_modules,
-        model.graph,
-        None, # qhandler
-        PrepareCustomConfig(),
-        None, # backend_config
     )

     # this returns the new observer node if it was needed
@@ -406,14 +406,15 @@ class QNNPackQuantizer(Quantizer):
         if _is_annotated([relu_node, conv_node]):
             return

+        input_node = conv_node.args[0]
+        weight_node = conv_node.args[1]
+        bias_node = conv_node.args[2]
         conv_node.meta["target_dtype_info"] = {
-            "input_act_obs_or_fq_ctr": get_act_obs_or_fq_ctr(quantization_config),
-            "weight_obs_or_fq_ctr": get_weight_obs_or_fq_ctr(quantization_config),
-            "bias_obs_or_fq_ctr": get_bias_obs_or_fq_ctr(quantization_config),
-            # TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
-            "weight_index": 1,
-            # TODO: validation of bias_index must be set if bias_obs_or_fq_ctr is set
-            "bias_index": 2,
+            "input_act_obs_or_fq_ctr_map": {
+                input_node: get_act_obs_or_fq_ctr(quantization_config),
+                weight_node: get_weight_obs_or_fq_ctr(quantization_config),
+                bias_node: get_bias_obs_or_fq_ctr(quantization_config),
+            },
             "_annotated": True,
         }
         relu_node.meta["target_dtype_info"] = {
@@ -433,15 +434,17 @@ class QNNPackQuantizer(Quantizer):
         # skip annotation if it is already annotated
         if _is_annotated([conv_node]):
             return

+        input_node = conv_node.args[0]
+        weight_node = conv_node.args[1]
+        bias_node = conv_node.args[2]
         conv_node.meta["target_dtype_info"] = {
-            "input_act_obs_or_fq_ctr": get_act_obs_or_fq_ctr(quantization_config),
-            "weight_obs_or_fq_ctr": get_weight_obs_or_fq_ctr(quantization_config),
-            "bias_obs_or_fq_ctr": get_bias_obs_or_fq_ctr(quantization_config),
+            "input_act_obs_or_fq_ctr_map": {
+                input_node: get_act_obs_or_fq_ctr(quantization_config),
+                weight_node: get_weight_obs_or_fq_ctr(quantization_config),
+                bias_node: get_bias_obs_or_fq_ctr(quantization_config),
+            },
+            "output_act_obs_or_fq_ctr": get_act_obs_or_fq_ctr(quantization_config),
-            # TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
-            "weight_index": 1,
-            # TODO: validation of bias_index must be set if bias_obs_or_fq_ctr is set
-            "bias_index": 2,
             "_annotated": True,
         }
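With the annotation keyed by node rather than by positional index, a quantizer can configure each argument independently. A hedged sketch for a two-input op, where the non-zeroth input gets its own, dynamic-friendly observer (observer choices are stand-ins for whatever the quantization_config would select):

```
# Hedged sketch: annotating both inputs of a matmul independently via the
# new map. PlaceholderObserver is used as a stand-in for a dynamic-activation
# observer; the old index-based schema could not express this for arg 1.
import torch
import torch.fx
from torch.ao.quantization.observer import MinMaxObserver, PlaceholderObserver

class MM(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)

gm = torch.fx.symbolic_trace(MM())
mm = next(n for n in gm.graph.nodes if n.target is torch.matmul)
x_node, y_node = mm.args

mm.meta["target_dtype_info"] = {
    "input_act_obs_or_fq_ctr_map": {
        # static observer on the zeroth arg ...
        x_node: MinMaxObserver.with_args(dtype=torch.quint8),
        # ... and a placeholder-style observer on the *non-zeroth* arg
        y_node: PlaceholderObserver.with_args(dtype=torch.float),
    },
    "_annotated": True,
}
```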
@@ -144,10 +144,9 @@ def _needs_obs_or_fq(

     is_zeroth_arg: we only dynamically quantize the first arg of the node right now
         this should be removed when we enable configuring dynamic quantization
-        for a specific argument
+        for a specific argument, this can be removed if we deprecate fx graph mode
+        quantization

-    Note: we want to refactor the annotation specification api soon to use
-    a map from user_node to obs_or_fq_ctr
     """

     # need to insert placeholder observer for dynamic quantization so that it can
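For intuition, a hedged sketch of the gating that `is_zeroth_arg` feeds into; this mirrors the documented intent above, not the exact upstream predicate:

```
# Hedged sketch of the decision _needs_obs_or_fq makes. Under fx graph mode
# only the first argument may be dynamically quantized, while the pt2e path
# passes is_zeroth_arg=True unconditionally so any argument can be.
import torch

def needs_obs_or_fq_sketch(
    prev_output_dtype: torch.dtype,
    prev_output_is_dynamic: bool,
    target_dtype: torch.dtype,
    target_is_dynamic: bool,
    reuse_input_obs_or_fq: bool,
    is_zeroth_arg: bool,
) -> bool:
    if target_is_dynamic:
        # dynamic quantization inserts a placeholder observer, but only on
        # arguments that are allowed to be dynamically quantized
        return is_zeroth_arg and not prev_output_is_dynamic
    if reuse_input_obs_or_fq:
        return False
    # static case: observe whenever the dtype the consumer wants differs
    # from the dtype the producer yields
    return target_dtype != prev_output_dtype
```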
@@ -512,120 +511,83 @@ def _get_target_activation_dtype_for_node(
         }
     return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)

+def _get_output_act_obs_or_fq_ctr(
+    arg: Node,
+    named_modules: Dict[str, torch.nn.Module],
+) -> Any:
+    """ Get the constructor for observer or fake quant object for
+    the argument in the original graph as the output of previous node,
+    skipping inserted observers
+
+    We are assuming that the observers are inserted correctly, and the dtype for
+    argument in quantized graph will match what is specified by the qconfig
+    """
+    assert isinstance(arg, Node)
+    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
+    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
+    # Since we modified the graph in this case, we must trace back from the args through
+    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
+    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
+    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
+    output_act_obs_or_fq_ctr = None
+    if custom_module_lstm_node is not None:
+        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+    elif _is_activation_post_process_node(arg, named_modules):
+        observed_arg = arg.args[0]
+        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
+        output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+    else:
+        if "target_dtype_info" in arg.meta:
+            output_act_obs_or_fq_ctr = \
+                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+        else:
+            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
+    return output_act_obs_or_fq_ctr

 def _get_arg_target_dtype_as_output(
     arg: Node,
     named_modules: Dict[str, torch.nn.Module],
-) -> Optional[Union[torch.dtype, type]]:
-    """ Get the target output activation dtype for
-    the argument in the original graph, skipping inserted observers
-    We are assuming that the observers are inserted correctly, and the dtype for
-    argument in quantized graph will match what is specified by the qconfig
-    """
-    assert isinstance(arg, Node)
-    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
-    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
-    # Since we modified the graph in this case, we must trace back from the args through
-    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
-    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
-    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
-    output_act_obs_or_fq_ctr = None
-    if custom_module_lstm_node is not None:
-        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    elif _is_activation_post_process_node(arg, named_modules):
-        observed_arg = arg.args[0]
-        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
-        output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    else:
-        if "target_dtype_info" in arg.meta:
-            output_act_obs_or_fq_ctr = \
-                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        else:
-            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
-    output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
-    # TODO: should support is_dynamic here as well
-    return output_act_dtype
+) -> Optional[torch.dtype]:
+    arg_as_output_act_obs_or_fq_ctr = _get_output_act_obs_or_fq_ctr(arg, named_modules)
+    arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq_ctr)
+    return arg_as_output_target_dtype

-
-# TODO: merge with _get_arg_target_dtype_as_output
-def _get_arg_target_is_dynamic_as_output(
-    arg: Node,
-    named_modules: Dict[str, torch.nn.Module],
-) -> bool:
-    """ Get the target output activation dtype for
-    the argument in the original graph, skipping inserted observers
-    We are assuming that the observers are inserted correctly, and the dtype for
-    argument in quantized graph will match what is specified by the qconfig
-    """
-    assert isinstance(arg, Node)
-    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
-    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
-    # Since we modified the graph in this case, we must trace back from the args through
-    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
-    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
-    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
-    output_act_obs_or_fq_ctr = None
-    if custom_module_lstm_node is not None:
-        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    elif _is_activation_post_process_node(arg, named_modules):
-        observed_arg = arg.args[0]
-        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
-        output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
-    else:
-        if "target_dtype_info" in arg.meta:
-            output_act_obs_or_fq_ctr = \
-                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        else:
-            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
-    _, output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
-    # return output_is_dynamic
-    return False

-def _get_arg_target_dtype_as_input_to_node(
+def _get_arg_as_input_act_obs_or_fq_ctr(
     arg: Node,
     node: Node,
     named_modules: Dict[str, torch.nn.Module],
-) -> Optional[Union[torch.dtype, type]]:
-    """ Get the target argument dtype for the argument `arg`, as input
-    to node `node`
+) -> Any:
+    """ Get the observer or fake quant constructor for the Argument `arg`, as input
+    to Node `node`
     """
     assert isinstance(arg, Node)
+    # "input_act_obs_or_fq_ctr_map" is the more general design we'll use for pt2e path
+    # it is a map from input argument node to observer or fake quant constructor, for example
+    # for the following graph:
+    # x -> conv -> output
+    #
+    # we may annotate conv node like the following:
+    # conv.meta[...] = {"input_act_obs_or_fq_ctr_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...}
+    #
+    if "target_dtype_info" in node.meta and "input_act_obs_or_fq_ctr_map" in node.meta["target_dtype_info"]:
+        input_act_obs_or_fq_ctr = \
+            node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr_map"].get(arg, _DEFAULT_FP32_OBS_OR_FQ_CTR)
+        return input_act_obs_or_fq_ctr
+
+    # we can remove the following path in the future if fx graph mode quantization is
+    # no longer used
     is_weight = node_arg_is_weight(node, arg)
     is_bias = node_arg_is_bias(node, arg)
     is_activation = not is_weight and not is_bias
+    obs_or_fq_ctr = None
     if is_activation:
-        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        qconfig_dtype, _ = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
-        return qconfig_dtype
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
     elif is_weight:
-        if node.target in NON_QUANTIZABLE_WEIGHT_OPS:
-            return None
-        else:
-            weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-            qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq_ctr)
-            return qconfig_weight_dtype
+        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
+            obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
     else:
-        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq_ctr)
-        return qconfig_bias_dtype
-
-def _get_arg_target_is_dynamic_as_input_to_node(
-    arg: Node,
-    node: Node,
-    named_modules: Dict[str, torch.nn.Module],
-) -> bool:
-    """ Get the target argument dtype for the argument `arg`, as input
-    to node `node`
-    """
-    assert isinstance(arg, Node)
-    is_weight = node_arg_is_weight(node, arg)
-    is_bias = node_arg_is_bias(node, arg)
-    is_activation = not is_weight and not is_bias
-    if is_activation and "input_act_obs_or_fq_ctr" in node.meta["target_dtype_info"]:
-        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
-        _, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
-        return qconfig_is_dynamic
-    else:
-        return False
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+    return obs_or_fq_ctr

 def _maybe_insert_input_observer_for_arg_or_kwarg(
     node: Union[Node, Any],
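The new `_get_arg_as_input_act_obs_or_fq_ctr` path reduces to a dict lookup keyed by the argument node. A minimal sketch (the fp32 fallback constructor below is a stand-in for the module's `_DEFAULT_FP32_OBS_OR_FQ_CTR`):

```
# Hedged sketch of the map lookup implemented above: a plain dict get keyed
# by the argument node, with a float32 placeholder-style constructor as the
# fallback for unannotated arguments.
import torch
from torch.fx import Node
from torch.ao.quantization.observer import PlaceholderObserver

_FP32_FALLBACK_CTR = PlaceholderObserver.with_args(dtype=torch.float)

def get_input_obs_or_fq_ctr_sketch(node: Node, arg: Node):
    info = node.meta.get("target_dtype_info", {})
    ctr_map = info.get("input_act_obs_or_fq_ctr_map")
    if ctr_map is not None:
        return ctr_map.get(arg, _FP32_FALLBACK_CTR)
    return _FP32_FALLBACK_CTR
```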
@@ -668,7 +630,6 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
         # Note: qconfig can be None in this branch this we are getting act/fq from
         # node.meta now
         # regular flow for most nodes, except standalone modules
-        is_weight = node_arg_is_weight(node, arg)

         # TODO: we are assuming "target_dtype_info" exists here, maybe
         # a default value also need to be provided here
@@ -676,16 +637,14 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
         # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
         # we'll default to False, this makes configuring this field optional for users
         reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
+        arg_as_input_act_obs_or_fq_ctr = _get_arg_as_input_act_obs_or_fq_ctr(arg, node, named_modules)
+        act_post_process_ctr = arg_as_input_act_obs_or_fq_ctr
+
+        arg_as_output_act_obs_or_fq_ctr = _get_output_act_obs_or_fq_ctr(arg, named_modules)
+        arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq_ctr)
+        arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq_ctr)
-
-        # TODO: check bias as well?
-        obs_or_fq_ctr_key = "weight_obs_or_fq_ctr" if is_weight else "input_act_obs_or_fq_ctr"
-        act_post_process_ctr = target_dtype_info.get(obs_or_fq_ctr_key, _DEFAULT_FP32_OBS_OR_FQ_CTR)
-
-        arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules)
-        arg_as_output_target_is_dynamic = _get_arg_target_is_dynamic_as_output(arg, named_modules)
-        arg_as_input_target_dtype = _get_arg_target_dtype_as_input_to_node(arg, node, named_modules)
-        arg_as_input_target_is_dynamic = \
-            _get_arg_target_is_dynamic_as_input_to_node(arg, node, named_modules)  # type: ignore[arg-type]
         needs_obs_or_fq = _needs_obs_or_fq(
             arg_as_output_target_dtype,
             arg_as_output_target_is_dynamic,
@@ -860,6 +819,9 @@ def _maybe_insert_output_observer_for_node(
     and returns it.

    If `node` does not need an output observer, returns None.
+
+    Note: inserting dynamic quantization ops for output is not supported in fx graph mode
+    quantization code path right now
    """
    assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'