Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62038

Updates the logic to extract weights from nodes to use a direct mapping from
type to weight extraction function. This is needed for a future PR which will
allow users to specify custom weight extraction functions for user defined types.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
python test/test_quantization.py TestFXNumericSuiteCoreAPIsModels
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D29853627

fbshipit-source-id: 3ef90ef4bd7b28f6316c0af215a2bd3ff8a2aeca
493 lines · 20 KiB · Python
import collections

import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.ns.mappings import (
    get_base_name_to_sets_of_related_ops,
)
from torch.quantization.ns.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)

from .ns.weight_utils import (
    extract_weight_from_node,
)

from .ns.graph_passes import (
    add_loggers_to_model,
    create_a_shadows_b,
)

from .ns.utils import (
    rekey_logger_info_on_node_name_of_model,
    maybe_add_missing_fqns,
)

from .ns.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
    NSNodeTargetType,
)

from typing import Dict, Tuple, Callable, List, Optional, Set

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class OutputLogger(nn.Module):
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging the input of op1 and logger2 is logging
        # the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        # - logger1's prev_node_name is x1 and ref_node_name is op1
        # - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn

    # Note: cannot annotate the type of x because TorchScript does not support
    #   the Union type.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg}, fqn={self.fqn})"""


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        if isinstance(m, torch.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = extract_weight_from_node(node, model)
        if extracted_weight:
            if ref_name not in results:
                results[ref_name] = {res_type: {}}
            results[ref_name][res_type][model_name] = [extracted_weight]


def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results)

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results


def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
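
# Example usage (illustrative sketch): compare the weights of a float model and
# its FX-quantized copy. `M` is a hypothetical eager-mode model class and the
# calibration step is elided; only `extract_weights` is defined in this module.
#
#     import copy
#     m = M().eval()
#     qconfig_dict = {'': torch.quantization.default_qconfig}
#     mp = quantize_fx.prepare_fx(copy.deepcopy(m), qconfig_dict)
#     # ... run calibration data through mp ...
#     mq = quantize_fx.convert_fx(mp)
#     # nested dict keyed by layer name, results type, and model name
#     weight_results = extract_weights('fp32', m, 'int8', mq)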


def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    #   about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, str] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, str] = {}
    for node, ref_name in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = ref_name
    for node, ref_name in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = ref_name

    model = add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append((subgraph_a.start_node, match_name))
            nodes_and_names_to_instrument_inputs_b.append((subgraph_b.start_node, match_name))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append((subgraph_a.end_node, match_name))
        nodes_and_names_to_instrument_outputs_b.append((subgraph_b.end_node, match_name))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)


def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)


def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
                'fqn': mod.fqn,
            })
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res:
                f"{res['index_of_arg']}:{res['index_within_arg']}"
            )


# TODO(future PR): align on naming
# this is equivalent to just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return results
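
# Illustrative shape of the returned NSResultsType: a nested dict keyed by
# layer name, then results type, then model name, with one entry per logger.
# The layer and model names below are hypothetical; the per-entry keys are the
# logger fields recorded in _extract_logger_info_one_model above.
#
#     {
#         'some_layer_name': {
#             'node_output': {
#                 'fp32': [{'type': ..., 'values': [...], 'ref_node_name': ..., ...}],
#                 'int8': [{'type': ..., 'values': [...], 'ref_node_name': ..., ...}],
#             },
#         },
#     }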


def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return dict(results)
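
# Continuing the sketch above: the shadow model's logged values can then be
# collected into the same NSResultsType layout, rekeyed on the node names of
# the model named 'int8'. The names are hypothetical.
#
#     shadow_results = extract_shadow_logger_info(
#         m_shadows_mq, OutputLogger, 'int8')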


def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`.
    """
    for _, results_type_to_results in results.items():
        for _, model_name_to_results in results_type_to_results.items():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2['index_within_arg']
                index_of_arg_2 = result_2['index_of_arg']
                # find corresponding result_1
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1['index_within_arg']
                    index_of_arg_1 = cur_result_1['index_of_arg']
                    if (
                        (index_within_arg_1 == index_within_arg_2) and
                        (index_of_arg_1 == index_of_arg_2)
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1['values']
                values_2 = result_2['values']
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)
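
# Example usage (illustrative sketch): append a comparison metric to the
# entries of model 'int8', measured against model 'fp32'. `compute_sqnr` is a
# hypothetical helper, not defined in this module; any
# Callable[[torch.Tensor, torch.Tensor], torch.Tensor] works.
#
#     def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         # signal-to-quantization-noise ratio, in dB
#         Ps = torch.norm(x)
#         Pn = torch.norm(x - y)
#         return 20 * torch.log10(Ps / Pn)
#
#     extend_logger_results_with_comparison(
#         act_results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32')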