Files
pytorch/torch/quantization/_numeric_suite_fx.py
Vasiliy Kuznetsov a359cfac22 ns for fx: add option to skip matching classes and functions (#57026)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57026

Adds a config option to skip matching classes by class type
and functions by function type.

This is useful when users make custom modules which return
types other than tensors. With the current implementation of
Logger, these are not scriptable.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_user_module_scriptable
```

Reviewed By: jerryzh168

Differential Revision: D28030093

Pulled By: vkuzo

fbshipit-source-id: 71dc54dd935d2071c4b017260ea2a1e5c2298bfe
2021-04-27 16:29:00 -07:00

400 lines
15 KiB
Python

import collections
import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.ns.graph_matcher import (
get_matching_subgraph_pairs,
get_base_name_to_sets_of_related_ops,
get_type_a_related_to_b,
)
from .ns.weight_utils import (
extract_weight_from_node,
)
from .ns.graph_passes import (
remove_observers_add_loggers,
create_a_shadows_b,
)
from .ns.ns_types import (
NSSingleResultValuesType,
NSResultsType,
NSNodeTargetType,
)
from typing import Dict, Tuple, Callable, List, Optional, Set
# Nested structure (tensor, (tensor, tensor)); presumably the output layout of
# LSTM-style modules - TODO confirm against the modules being logged.
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class OutputLogger(nn.Module):
    """
    Pass-through module which records detached copies of the values it sees.

    Plain tensors are accumulated in `stats`; two-element tuples whose second
    element also has two elements are accumulated in `stats_rnn`. Any other
    input is returned untouched and not recorded.
    """
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # Node on whose behalf this logger was inserted.
        #   - when logging a node's output, equal to prev_node_name
        #   - when logging a node's input, the node consuming that input
        #
        # Illustration, with logger1 on the input of op1 and logger2 on
        # its output:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        #   logger1: prev_node_name == x1,  ref_node_name == op1
        #   logger2: prev_node_name == op1, ref_node_name == op1
        self.ref_node_name = ref_node_name

        # Node whose output is captured by this logger.
        self.prev_node_name = prev_node_name

        # Name of the model the node came from.
        self.model_name = model_name

        # Key used to pair up loggers which live in different models.
        self.ref_name = ref_name

        # Target type of the node whose output is being logged.
        self.prev_node_target_type = prev_node_target_type

        # Describes what kind of values end up in the stats lists.
        self.results_type = results_type

        # Position of this node inside a single argument of the input/output
        # node, e.g. x2 in cat([x1, x2, x3], dim=0) has index_within_arg == 1.
        self.index_within_arg = index_within_arg

        # Position of this node among the arguments of the input/output
        # node, e.g. x2 in add(x1, x2) has index_of_arg == 1.
        self.index_of_arg = index_of_arg

    # Note: x is deliberately left unannotated, TorchScript does not support
    # the Union type. The if/elif shape below is kept as is for the same
    # reason: the module has to remain scriptable.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            detached = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(detached)
        return x

    def __repr__(self):
        return (
            f"OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},\n"
            f"prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},\n"
            f"results_type={self.results_type}, index_within_arg={self.index_within_arg},\n"
            f"index_of_arg={self.index_of_arg})"
        )
class NSTracer(quantize_fx.QuantizationTracer):
    """
    Quantization tracer which additionally stops at observer and
    fake_quantize modules, i.e. treats them as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        # observers and fake quants are never traced into
        ns_leaf_bases = (
            torch.quantization.ObserverBase,
            torch.quantization.FakeQuantizeBase,
        )
        if isinstance(m, ns_leaf_bases):
            return True
        # everything else follows the parent tracer's rules
        return super().is_leaf_module(m, module_qualified_name)
def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
) -> None:
    """
    Extract weights for the given nodes of one model and store them in
    `results`, which is modified in place.

    For each `(node, ref_name)` pair, the value returned by
    `extract_weight_from_node` (when truthy) is stored, wrapped in a
    single-element list, under
    `results[ref_name][<weight result type>][model_name]`.
    """
    # relatedness is always derived from the default set of related ops
    type_a_related_to_b = get_type_a_related_to_b(
        get_base_name_to_sets_of_related_ops())
    weight_key = NSSingleResultValuesType.WEIGHT.value

    for node, ref_name in nodes_and_names_to_instrument:
        # the entry for ref_name is created even if no weight is found
        if ref_name not in results:
            results[ref_name] = {weight_key: {}}
        weight = extract_weight_from_node(node, model, type_a_related_to_b)
        if not weight:
            continue
        results[ref_name][weight_key][model_name] = [weight]
def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    """
    Match the subgraphs of two graph modules and extract the weights of the
    base op node of every matched pair.

    Returns a fresh results dictionary filled first for model A and then
    for model B.
    """
    named_matches = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # one (base op node, match name) list per model
    base_nodes_a: List[Tuple[Node, str]] = [
        (subgraph_a.base_op_node, match_name)
        for match_name, (subgraph_a, _subgraph_b) in named_matches.items()
    ]
    base_nodes_b: List[Tuple[Node, str]] = [
        (subgraph_b.base_op_node, match_name)
        for match_name, (_subgraph_a, subgraph_b) in named_matches.items()
    ]

    # fill in the results for each model in turn
    results: NSResultsType = {}
    for name, gm, base_nodes in (
        (model_name_a, gm_a, base_nodes_a),
        (model_name_b, gm_b, base_nodes_b),
    ):
        _extract_weights_one_model(name, gm, base_nodes, results)
    return results
def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    """
    Trace two models, match their subgraphs and extract the weights of the
    matched base ops.

    Args:
        model_name_a: name under which results of `model_a` are stored
        model_a: first model, traced with `NSTracer`
        model_name_b: name under which results of `model_b` are stored
        model_b: second model, traced with `NSTracer`
        base_name_to_sets_of_related_ops: optional override of the related
            ops mapping forwarded to the subgraph matcher; when None, the
            value of `get_base_name_to_sets_of_related_ops()` is used
        unmatchable_types_map: optional value forwarded unchanged to the
            subgraph matcher

    Returns:
        NSResultsType with the extracted weights of both models.
    """
    # Only fall back to the default mapping if the caller did not provide
    # one; overwriting it unconditionally would silently discard the
    # caller's override.
    # NOTE(review): `_extract_weights_one_model` still derives type
    # relatedness from the default mapping, so a custom mapping only affects
    # subgraph matching here.
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str]],
    logger_cls: Callable,
) -> nn.Module:
    """
    Instrument a single model with loggers.

    The `(node, ref_name)` lists are turned into node -> ref_name lookups
    (a later entry for the same node wins) and handed to
    `remove_observers_add_loggers`, whose result is returned.
    """
    # TODO(future PR): do not observe nodes we do not care
    #   about (both fp32, denylist, etc)
    input_node_to_ref_name: Dict[Node, str] = dict(
        nodes_and_names_to_instrument_inputs)
    output_node_to_ref_name: Dict[Node, str] = dict(
        nodes_and_names_to_instrument_outputs)
    return remove_observers_add_loggers(
        model, input_node_to_ref_name, output_node_to_ref_name,
        logger_cls, model_name)
def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Match the subgraphs of two graph modules and add loggers to both.

    Outputs are always logged at the end node of each matched subgraph.
    Inputs are logged at the start node, and only if `should_log_inputs`
    is set.

    Returns the instrumented `(model_a, model_b)` pair.
    """
    named_pairs = list(get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map).items())

    # Inputs are observed at start_node, e.g. the input of linear in
    # linear-relu.
    inputs_a, inputs_b = [], []
    if should_log_inputs:
        inputs_a = [(sg_a.start_node, name) for name, (sg_a, _sg_b) in named_pairs]
        inputs_b = [(sg_b.start_node, name) for name, (_sg_a, sg_b) in named_pairs]

    # Activations are always observed at end_node, e.g. the output of relu
    # in linear-relu.
    outputs_a = [(sg_a.end_node, name) for name, (sg_a, _sg_b) in named_pairs]
    outputs_b = [(sg_b.end_node, name) for name, (_sg_a, sg_b) in named_pairs]

    instrumented_a = _add_loggers_one_model(
        name_a, gm_a, inputs_a, outputs_a, logger_cls)
    instrumented_b = _add_loggers_one_model(
        name_b, gm_b, inputs_b, outputs_b, logger_cls)
    return (instrumented_a, instrumented_b)
def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Trace both models with `NSTracer` and add loggers to their matching
    subgraphs.

    Returns the instrumented `(model_a, model_b)` pair produced by
    `_add_loggers_impl`.
    """
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    # each model gets its own tracer instance
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    graph_a = tracer_a.trace(model_a)
    gm_a = GraphModule(model_a, graph_a)
    graph_b = tracer_b.trace(model_b)
    gm_b = GraphModule(model_b, graph_b)

    matcher_options = dict(
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map,
    )
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        **matcher_options)
def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    """
    Collect the values recorded by every logger of `model` into `results`.

    A submodule counts as a logger if it is an instance of `logger_cls`, or
    if it is a scripted module whose original name is 'OutputLogger'.

    `results` is modified in place. One entry per logger is appended to
    `results[ref_name][results_type][model_name]`, and that list is kept
    sorted by `(index_of_arg, index_within_arg)`.

    Args:
        model: model whose submodules are scanned for loggers
        results: results dictionary to add to
        logger_cls: class of the loggers to look for
    """
    for _module_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if not is_logger:
            continue

        # Several loggers may legitimately share a ref_name, results type
        # and model name (one per logged input), so entries accumulate in a
        # list and nothing is asserted about model_name already being
        # present.
        model_name_to_entries = (
            results.setdefault(mod.ref_name, {})
            .setdefault(mod.results_type, {})
        )
        entries = model_name_to_entries.setdefault(mod.model_name, [])

        # stats_rnn takes precedence whenever it holds any values
        stats_to_use = mod.stats_rnn if len(mod.stats_rnn) > 0 else mod.stats
        entries.append({
            'type': mod.results_type,
            'values': stats_to_use,
            'ref_node_name': mod.ref_node_name,
            'prev_node_name': mod.prev_node_name,
            'prev_node_target_type': mod.prev_node_target_type,
            'index_within_arg': mod.index_within_arg,
            'index_of_arg': mod.index_of_arg,
        })
        # Ensure the list stays sorted. The key is a tuple of ints so that
        # the order is numeric; a formatted string key would place index 10
        # ahead of index 2.
        entries.sort(
            key=lambda res: (res['index_of_arg'], res['index_within_arg'])
        )
# TODO(future PR): align on naming
# this is the equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
) -> NSResultsType:
    """
    Counterpart of ns.extract_logger_info for models which were prepared
    with this module.

    The loggers of `model_a` are processed first, followed by those of
    `model_b`; both feed the same results dictionary.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    results: NSResultsType = {}
    _extract_logger_info_one_model(model_a, results, logger_cls)
    _extract_logger_info_one_model(model_b, results, logger_cls)
    return results
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Match the subgraphs of two graph modules and build the `a_shadows_b`
    model from the matches via `create_a_shadows_b`.
    """
    subgraph_matches = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    return create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, subgraph_matches, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Like add_loggers, except that a single `a_shadows_b` model is produced
    instead of a pair of instrumented models.

    TODO(future PR): real docblock
    """
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    # each model gets its own tracer instance
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    graph_a = tracer_a.trace(model_a)
    gm_a = GraphModule(model_a, graph_a)
    graph_b = tracer_b.trace(model_b)
    gm_b = GraphModule(model_b, graph_b)

    impl_options = dict(
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map,
    )
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls, **impl_options)
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
) -> NSResultsType:
    """
    Counterpart of extract_logger_info for a single `a_shadows_b` model.

    TODO(future PR): real docblock
    """
    collected: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, collected, logger_cls)
    # hand a plain dict back to the caller rather than the defaultdict
    plain_results = dict(collected)
    return plain_results