mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Add Custom Module Support List (#82606)
Summary: Add a global custon module support list for the users to specify the modules they want the equalization process support. To use this list, import it from the _equalize.py file and append module in it. Unittest passed to check global support list: https://pxl.cl/28RKG Test Plan: buck1 test mode/dev //on_device_ai/odai/tests/transforms:test_transforms -- --exact 'on_device_ai/odai/tests/transforms:test_transforms - test_custom_support_list (on_device_ai.odai.tests.transforms.test_input_weight_for_turing.TestInputWeight)' Reviewed By: jerryzh168 Differential Revision: D38264244 Pull Request resolved: https://github.com/pytorch/pytorch/pull/82606 Approved by: https://github.com/HDCharles
This commit is contained in:
@ -1,3 +1,8 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -5,25 +10,16 @@ import torch.nn.intrinsic as nni
|
|||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
from torch.fx.graph import Node
|
from torch.fx.graph import Node
|
||||||
|
|
||||||
|
from ..observer import _with_args, ObserverBase, PerChannelMinMaxObserver
|
||||||
|
from ..utils import _parent_name, check_min_max_valid
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
WEIGHT_INDEX_DICT,
|
|
||||||
get_new_attr_name_with_prefix,
|
get_new_attr_name_with_prefix,
|
||||||
maybe_get_next_module,
|
maybe_get_next_module,
|
||||||
)
|
WEIGHT_INDEX_DICT,
|
||||||
from ..observer import (
|
|
||||||
PerChannelMinMaxObserver,
|
|
||||||
_with_args,
|
|
||||||
ObserverBase,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
check_min_max_valid,
|
|
||||||
_parent_name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from collections import namedtuple
|
CUSTOM_MODULE_SUPP_LIST: List[Any] = []
|
||||||
from typing import Dict, Any, List, Tuple, Optional
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
|
def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
|
||||||
"""Reshapes the scale so that we can multiply it to the input by the given axis.
|
"""Reshapes the scale so that we can multiply it to the input by the given axis.
|
||||||
@ -241,13 +237,19 @@ def nn_module_supports_equalization(module) -> bool:
|
|||||||
""" Checks if the torch.nn node supports equalization. """
|
""" Checks if the torch.nn node supports equalization. """
|
||||||
return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]
|
return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]
|
||||||
|
|
||||||
|
def custom_module_supports_equalization(module) -> bool:
|
||||||
|
""" Checks if the custom node supports equalization. """
|
||||||
|
return type(module) in CUSTOM_MODULE_SUPP_LIST
|
||||||
|
|
||||||
|
|
||||||
def node_supports_equalization(node: Node, modules) -> bool:
|
def node_supports_equalization(node: Node, modules) -> bool:
|
||||||
""" Checks if the current node supports equalization
|
""" Checks if the current node supports equalization
|
||||||
Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers
|
Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers
|
||||||
"""
|
"""
|
||||||
if node.op == 'call_module':
|
if node.op == 'call_module':
|
||||||
return nn_module_supports_equalization(modules[str(node.target)]) or \
|
return nn_module_supports_equalization(modules[str(node.target)]) or \
|
||||||
fused_module_supports_equalization(modules[str(node.target)])
|
fused_module_supports_equalization(modules[str(node.target)]) or \
|
||||||
|
custom_module_supports_equalization(modules[str(node.target)])
|
||||||
elif node.op == 'call_function':
|
elif node.op == 'call_function':
|
||||||
return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
|
return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||||
return False
|
return False
|
||||||
@ -413,7 +415,7 @@ def scale_weight_node(
|
|||||||
op_module = modules[str(node.target)][0] # type: ignore[index]
|
op_module = modules[str(node.target)][0] # type: ignore[index]
|
||||||
else:
|
else:
|
||||||
op_module = modules[str(node.target)]
|
op_module = modules[str(node.target)]
|
||||||
assert(nn_module_supports_equalization(op_module))
|
assert(nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module))
|
||||||
|
|
||||||
# Scale the weights for input-weight equalization
|
# Scale the weights for input-weight equalization
|
||||||
# If the following layer needs to be equalized then we will multiply its scale
|
# If the following layer needs to be equalized then we will multiply its scale
|
||||||
|
@ -17,6 +17,7 @@ from torch.ao.quantization.fx._equalize import (
|
|||||||
default_equalization_qconfig,
|
default_equalization_qconfig,
|
||||||
fused_module_supports_equalization,
|
fused_module_supports_equalization,
|
||||||
nn_module_supports_equalization,
|
nn_module_supports_equalization,
|
||||||
|
custom_module_supports_equalization,
|
||||||
node_supports_equalization,
|
node_supports_equalization,
|
||||||
is_equalization_observer,
|
is_equalization_observer,
|
||||||
get_op_node_and_weight_eq_obs,
|
get_op_node_and_weight_eq_obs,
|
||||||
@ -32,5 +33,6 @@ from torch.ao.quantization.fx._equalize import (
|
|||||||
convert_eq_obs,
|
convert_eq_obs,
|
||||||
_convert_equalization_ref,
|
_convert_equalization_ref,
|
||||||
get_layer_sqnr_dict,
|
get_layer_sqnr_dict,
|
||||||
get_equalization_qconfig_dict
|
get_equalization_qconfig_dict,
|
||||||
|
CUSTOM_MODULE_SUPP_LIST,
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user