mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 09:04:53 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61323 Before this PR, all observers and fake quants were silently removed when adding loggers with NS. This is problematic for QAT models because we need the fake quants to run in order to properly capture intermediate outputs. This PR fixes the issue by preserving the observers throughout the passes which add loggers. In detail: * for each quantization module or fusion, add additional patterns with that fusion and an observer/fake_quant at the end * remove the places in the logger model creation code which removed observers * add unit testing that QAT numerics do not change after adding loggers Test Plan: ``` python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_loggers_preserve_qat_numerics python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_shadow_loggers_preserve_qat_numerics ``` Imported from OSS Reviewed By: hx89 Differential Revision: D29600351 fbshipit-source-id: 5f25118b79eb47860c49bca882de6a8eae7a4456
474 lines
19 KiB
Python
import collections
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.quantization.quantize_fx as quantize_fx
|
|
from torch.fx import GraphModule
|
|
from torch.fx.graph import Node
|
|
from torch.quantization.ns.mappings import (
|
|
get_base_name_to_sets_of_related_ops,
|
|
)
|
|
from torch.quantization.ns.graph_matcher import (
|
|
get_matching_subgraph_pairs,
|
|
get_type_a_related_to_b,
|
|
)
|
|
|
|
from .ns.weight_utils import (
|
|
extract_weight_from_node,
|
|
)
|
|
|
|
from .ns.graph_passes import (
|
|
add_loggers_to_model,
|
|
create_a_shadows_b,
|
|
)
|
|
|
|
from .ns.utils import (
|
|
rekey_logger_info_on_node_name_of_model,
|
|
)
|
|
|
|
from .ns.ns_types import (
|
|
NSSingleResultValuesType,
|
|
NSResultsType,
|
|
NSNodeTargetType,
|
|
)
|
|
|
|
from typing import Dict, Tuple, Callable, List, Optional, Set
|
|
|
|
# Return structure of an RNN/LSTM forward pass: (output, (hidden_state, cell_state)).
# Used to annotate the stats_rnn list captured by OutputLogger.
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
|
|
|
class OutputLogger(nn.Module):
    """
    Pass-through module that records the values flowing through it.

    An instance is inserted into a traced model next to the node it observes;
    `forward` detaches and stores each tensor (or RNN result tuple) and
    returns the input unchanged, so model behavior is not affected.
    """
    # TorchScript requires class-level annotations for these instance lists.
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        # the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        # - logger1's prev_node_name is x1 and ref_node_name is op1
        # - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated from
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        """Record `x` (detached) into stats/stats_rnn and return it unchanged."""
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            # matches the (output, (hidden, cell)) shape of RNNReturnType
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg})"""
|
|
|
|
|
|
class NSTracer(quantize_fx.QuantizationTracer):
    """
    Identical to the regular quantization tracer, except that observer and
    fake_quantize modules are kept as leaf modules, so they survive symbolic
    tracing instead of being traced through.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Observers and fake quants are leaves; everything else defers to
        # the parent tracer's decision.
        leaf_types = (
            torch.quantization.ObserverBase,
            torch.quantization.FakeQuantizeBase,
        )
        if isinstance(m, leaf_types):
            return True
        return super().is_leaf_module(m, module_qualified_name)
|
|
|
|
|
|
def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
) -> None:
    """
    Extracts the weight (if any) for each (node, ref_name) pair of `model`
    and records it under results[ref_name][weight_type][model_name].

    Mutates `results` in place; returns nothing.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # loop-invariant, so compute it once
    res_type = NSSingleResultValuesType.WEIGHT.value
    for node, ref_name in nodes_and_names_to_instrument:
        extracted_weight = \
            extract_weight_from_node(node, model, type_a_related_to_b)
        if extracted_weight:
            # setdefault (rather than replacing the whole entry) so that a
            # ref_name that already holds other result types is extended,
            # not clobbered, and no KeyError can occur on a partial entry
            results.setdefault(ref_name, {}).setdefault(res_type, {})[model_name] = \
                [extracted_weight]
|
|
|
|
|
|
def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    """
    Matches subgraphs between gm_a and gm_b, extracts the weights of the
    matched base ops from both models, and returns the results keyed on
    the node names of model b.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = [
        (subgraph_a.base_op_node, match_name)
        for match_name, (subgraph_a, _) in matched_subgraph_pairs.items()
    ]
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = [
        (subgraph_b.base_op_node, match_name)
        for match_name, (_, subgraph_b) in matched_subgraph_pairs.items()
    ]

    # populate the results, one model at a time
    results: NSResultsType = {}
    per_model = (
        (model_name_a, gm_a, nodes_and_names_to_instrument_a),
        (model_name_b, gm_b, nodes_and_names_to_instrument_b),
    )
    for model_name, gm, nodes_and_names in per_model:
        _extract_weights_one_model(model_name, gm, nodes_and_names, results)

    # rekey on names of nodes in gm_b
    return rekey_logger_info_on_node_name_of_model(results, model_name_b)
|
|
|
|
|
|
def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    """
    Traces both models with NSTracer (keeping observers/fake_quants as leaf
    modules) and returns the extracted weights of matched subgraphs, keyed
    on the node names of model b.  Output format: NSResultsType.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    # Bug fix: the previous implementation unconditionally recomputed the
    # default mapping here, silently discarding a caller-supplied
    # base_name_to_sets_of_related_ops; only fill in the default when the
    # caller did not provide one.  (The old type_a_related_to_b computation
    # was dead code and has been removed.)
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
|
|
|
|
|
|
def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str]],
    logger_cls: Callable,
) -> nn.Module:
    """
    Inserts `logger_cls` loggers into `model` at the given input and output
    nodes, tagging each logger with its ref_name and `model_name`, and
    returns the instrumented model.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    # about (both fp32, denylist, etc)
    # build node -> ref_name lookups directly from the pair lists
    node_to_instrument_inputs_to_ref_name: Dict[Node, str] = \
        dict(nodes_and_names_to_instrument_inputs)
    node_to_instrument_outputs_to_ref_name: Dict[Node, str] = \
        dict(nodes_and_names_to_instrument_outputs)

    return add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
|
|
|
|
|
|
def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Matches subgraphs between gm_a and gm_b and instruments both models
    with loggers at the matched locations, returning the two instrumented
    models as a tuple.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)

    inputs_to_instrument_a: List[Tuple[Node, str]] = []
    inputs_to_instrument_b: List[Tuple[Node, str]] = []
    outputs_to_instrument_a: List[Tuple[Node, str]] = []
    outputs_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            inputs_to_instrument_a.append((subgraph_a.start_node, match_name))
            inputs_to_instrument_b.append((subgraph_b.start_node, match_name))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        outputs_to_instrument_a.append((subgraph_a.end_node, match_name))
        outputs_to_instrument_b.append((subgraph_b.end_node, match_name))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, inputs_to_instrument_a,
        outputs_to_instrument_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, inputs_to_instrument_b,
        outputs_to_instrument_b, logger_cls)
    return (new_model_a, new_model_b)
|
|
|
|
|
|
def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Traces both models with NSTracer and returns the pair of models
    instrumented with `logger_cls` loggers at matched subgraph locations.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    graph_modules = []
    for model in (model_a, model_b):
        # each model gets its own tracer instance
        tracer = NSTracer(skipped_module_names, skipped_module_classes)
        graph_modules.append(GraphModule(model, tracer.trace(model)))
    gm_a, gm_b = graph_modules
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)
|
|
|
|
|
|
def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    """
    Walks `model` and copies the stats captured by every logger module into
    `results`, keyed as results[ref_name][results_type][model_name] -> list
    of per-logger result dicts, kept sorted by (index_of_arg,
    index_within_arg).

    Mutates `results` in place; returns nothing.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            # NOTE(review): results[key] is keyed by results_type, not by
            # model_name, so this assert can only fire on a pathological
            # name collision — confirm the intent before strengthening it.
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            # RNN loggers store their results in stats_rnn, everything
            # else in stats
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
            })
            # ensure the list stays sorted.  Bug fix: the previous string
            # key f"{index_of_arg}:{index_within_arg}" sorted
            # lexicographically, mis-ordering indices >= 10 ("0:10" < "0:2");
            # a numeric tuple key sorts correctly.
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res: (res['index_of_arg'], res['index_within_arg'])
            )
|
|
|
|
|
|
# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    _extract_logger_info_one_model(model_a, results, logger_cls)
    _extract_logger_info_one_model(model_b, results, logger_cls)
    # rekey on the name of model b
    return rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
|
|
|
|
|
|
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Matches subgraphs between gm_a and gm_b, then builds and returns a
    single combined model in which gm_a's matched subgraphs shadow gm_b's.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    return create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
|
|
|
|
|
|
def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    traced = []
    for model in (model_a, model_b):
        # each model gets its own tracer instance
        tracer = NSTracer(skipped_module_names, skipped_module_classes)
        traced.append(GraphModule(model, tracer.trace(model)))
    gm_a, gm_b = traced
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)
|
|
|
|
|
|
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    collected: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, collected, logger_cls)
    # rekey on the name of model b
    rekeyed = rekey_logger_info_on_node_name_of_model(
        collected, model_name_to_use_for_layer_names)
    # return a plain dict, not the defaultdict used internally
    return dict(rekeyed)
|
|
|
|
|
|
def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`.

    Raises AssertionError if either model is missing from a results entry,
    or if a result of model 2 has no counterpart in model 1 with the same
    (index_of_arg, index_within_arg). Mutates `results` in place.
    """
    for results_type_to_results in results.values():
        for model_name_to_results in results_type_to_results.values():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            # index results_1 once by (index_of_arg, index_within_arg) so
            # each lookup below is O(1) instead of a linear scan; setdefault
            # keeps the first entry on (unexpected) duplicate keys, matching
            # the previous first-match behavior
            results_1_by_index: Dict[Tuple[int, int], Dict] = {}
            for cur_result_1 in results_1:
                idx_key = (
                    cur_result_1['index_of_arg'],
                    cur_result_1['index_within_arg'],
                )
                results_1_by_index.setdefault(idx_key, cur_result_1)

            for result_2 in results_2:
                # find corresponding result_1
                result_1 = results_1_by_index.get(
                    (result_2['index_of_arg'], result_2['index_within_arg']))
                assert result_1 is not None

                values_1 = result_1['values']
                values_2 = result_2['values']
                result_2[comparison_name] = [
                    comparison_fn(value_1, value_2)
                    for value_1, value_2 in zip(values_1, values_2)
                ]