Add Custom Module Support List (#82606)

Summary:
Add a global custom module support list so that users can specify the modules they want the equalization process to support.

To use this list, import it from the _equalize.py file and append modules to it.
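
For example, a minimal sketch of the intended workflow (MyCustomLinear is a hypothetical user-defined module, not part of this change):

    import torch.nn as nn
    from torch.ao.quantization.fx._equalize import CUSTOM_MODULE_SUPP_LIST

    # Hypothetical custom module whose weight is laid out like nn.Linear,
    # so input-weight equalization can scale it the same way.
    class MyCustomLinear(nn.Linear):
        pass

    # Register the type; equalization will now treat call_module nodes
    # of this type as supported.
    CUSTOM_MODULE_SUPP_LIST.append(MyCustomLinear)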

Unit test passed to check the global support list:

https://pxl.cl/28RKG

Test Plan: buck1 test mode/dev //on_device_ai/odai/tests/transforms:test_transforms -- --exact 'on_device_ai/odai/tests/transforms:test_transforms - test_custom_support_list (on_device_ai.odai.tests.transforms.test_input_weight_for_turing.TestInputWeight)'

Reviewed By: jerryzh168

Differential Revision: D38264244

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82606
Approved by: https://github.com/HDCharles
Author: Hao Li
Date: 2022-08-03 17:48:51 +00:00
Committed by: PyTorch MergeBot
Commit: aa40503954 (parent 9d228fe517)

2 changed files with 21 additions and 17 deletions

--- a/torch/ao/quantization/fx/_equalize.py
+++ b/torch/ao/quantization/fx/_equalize.py

@@ -1,3 +1,8 @@
+import warnings
+from collections import namedtuple
+from typing import Any, Dict, List, Optional, Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -5,25 +10,16 @@ import torch.nn.intrinsic as nni
 from torch.fx import GraphModule
 from torch.fx.graph import Node
 
+from ..observer import _with_args, ObserverBase, PerChannelMinMaxObserver
+from ..utils import _parent_name, check_min_max_valid
 from .utils import (
-    WEIGHT_INDEX_DICT,
     get_new_attr_name_with_prefix,
     maybe_get_next_module,
+    WEIGHT_INDEX_DICT,
 )
-from ..observer import (
-    PerChannelMinMaxObserver,
-    _with_args,
-    ObserverBase,
-)
-from ..utils import (
-    check_min_max_valid,
-    _parent_name,
-)
 
-from collections import namedtuple
-from typing import Dict, Any, List, Tuple, Optional
-import warnings
+CUSTOM_MODULE_SUPP_LIST: List[Any] = []
 
 
 def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
     """Reshapes the scale so that we can multiply it to the input by the given axis.
@@ -241,13 +237,19 @@ def nn_module_supports_equalization(module) -> bool:
     """ Checks if the torch.nn node supports equalization. """
     return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]
 
+def custom_module_supports_equalization(module) -> bool:
+    """ Checks if the custom node supports equalization. """
+    return type(module) in CUSTOM_MODULE_SUPP_LIST
+
 def node_supports_equalization(node: Node, modules) -> bool:
     """ Checks if the current node supports equalization
     Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers
     """
     if node.op == 'call_module':
         return nn_module_supports_equalization(modules[str(node.target)]) or \
-            fused_module_supports_equalization(modules[str(node.target)])
+            fused_module_supports_equalization(modules[str(node.target)]) or \
+            custom_module_supports_equalization(modules[str(node.target)])
     elif node.op == 'call_function':
         return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
     return False
@@ -413,7 +415,7 @@ def scale_weight_node(
         op_module = modules[str(node.target)][0]  # type: ignore[index]
     else:
         op_module = modules[str(node.target)]
-    assert(nn_module_supports_equalization(op_module))
+    assert(nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module))
 
     # Scale the weights for input-weight equalization
     # If the following layer needs to be equalized then we will multiply its scale
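
With a type registered in CUSTOM_MODULE_SUPP_LIST, the new helper (and, through it, node_supports_equalization and the assert in scale_weight_node) accepts instances of that type. An illustrative sketch, reusing the hypothetical MyCustomLinear from the summary:

    import torch.nn as nn
    from torch.ao.quantization.fx._equalize import (
        CUSTOM_MODULE_SUPP_LIST,
        custom_module_supports_equalization,
    )

    class MyCustomLinear(nn.Linear):  # hypothetical example module
        pass

    CUSTOM_MODULE_SUPP_LIST.append(MyCustomLinear)

    # Membership is checked by exact type, not isinstance().
    print(custom_module_supports_equalization(MyCustomLinear(8, 8)))  # True
    print(custom_module_supports_equalization(nn.ReLU()))             # False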

--- a/test/quantization/fx/test_equalize_fx.py
+++ b/test/quantization/fx/test_equalize_fx.py

@@ -17,6 +17,7 @@ from torch.ao.quantization.fx._equalize import (
     default_equalization_qconfig,
     fused_module_supports_equalization,
     nn_module_supports_equalization,
+    custom_module_supports_equalization,
     node_supports_equalization,
     is_equalization_observer,
     get_op_node_and_weight_eq_obs,
@@ -32,5 +33,6 @@ from torch.ao.quantization.fx._equalize import (
     convert_eq_obs,
     _convert_equalization_ref,
     get_layer_sqnr_dict,
-    get_equalization_qconfig_dict
+    get_equalization_qconfig_dict,
+    CUSTOM_MODULE_SUPP_LIST,
 )