Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66484

https://github.com/pytorch/pytorch/pull/50748 added linear - bn1d fusion in eager mode, for PTQ only. This PR enables the same fusion in FX graph mode. We reuse the existing conv-bn-relu fusion handler, renaming `conv` to `conv_or_linear` for readability. The QAT version is left for a future PR, for both eager and FX graph mode.

Test Plan:
```
python test/test_quantization.py TestFuseFx.test_fuse_linear_bn_eval
```

Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D31575392

fbshipit-source-id: f69d80ef37c98cbc070099170e335e250bcdf913
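As a quick illustration of what this enables (a minimal sketch based on the test above; the toy module and the call to `torch.quantization.quantize_fx.fuse_fx` are illustrative assumptions, not code from this PR):

```
import torch
from torch.quantization.quantize_fx import fuse_fx

class LinearBn1d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.bn = torch.nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(self.linear(x))

m = LinearBn1d().eval()  # PTQ only: the model must be in eval mode
fused = fuse_fx(m)       # matches the new (BatchNorm1d, Linear) pattern registered below
# after fusion, the bn statistics are folded into the linear weights and
# the BatchNorm1d is replaced with an Identity in the fused module
print(fused)
```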
143 lines
6.9 KiB
Python
import torch
from torch.fx.graph import Node
from .pattern_utils import (
    register_fusion_pattern,
)
from .utils import _parent_name
from .quantization_types import QuantizerCls
from ..fuser_method_mappings import get_fuser_method
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional

# ---------------------
# Fusion Pattern Registrations
# ---------------------

# Base Pattern Handler
class FuseHandler(ABC):
    """Base handler class for the fusion patterns"""
    def __init__(self, quantizer: QuantizerCls, node: Node):
        pass

    @abstractmethod
    def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
             fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        pass

@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
class ConvOrLinearBNReLUFusion(FuseHandler):
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)
        self.relu_node = None
        self.bn_node = None
        # `node` is the last op of the matched pattern (relu or bn); walk
        # backwards through args[0] to find the bn and conv/linear nodes
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
            self.relu_node = node
            assert isinstance(node.args[0], Node)
            node = node.args[0]
        assert node.op == 'call_module'
        if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]:
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            assert isinstance(node.args[0], Node)
            node = node.args[0]
        assert node.op == 'call_module'
        self.conv_or_linear_node = node
        self.conv_or_linear = quantizer.modules[self.conv_or_linear_node.target]

    def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
             fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}
        additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
        op_list = []
        if self.relu_node is not None:
            # since relu can be used multiple times, we'll need to create a relu module for each match
            if self.relu_node.op == 'call_module':
                relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
            else:
                # TODO: get inplace argument from functional
                relu = torch.nn.ReLU()
            op_list.append(relu)
            relu.training = self.conv_or_linear.training
            if self.bn_node is not None:
                op_list.append(self.bn)
            op_list.append(self.conv_or_linear)
        else:
            assert self.bn_node is not None
            op_list.append(self.bn)
            op_list.append(self.conv_or_linear)

        # the modules were appended in relu - bn - conv_or_linear order, but the
        # fuser method expects conv_or_linear - bn - relu, so reverse the list
        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        conv_or_linear_parent_name, conv_or_linear_name = _parent_name(self.conv_or_linear_node.target)
        fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
        if fuser_method is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
        fused = fuser_method(*op_list)
        setattr(quantizer.modules[conv_or_linear_parent_name], conv_or_linear_name, fused)

        # TODO: do we need to make sure bn is only used once?
        if self.bn_node is not None:
            parent_name, name = _parent_name(self.bn_node.target)
            setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
        # relu may be used multiple times, so we don't set relu to identity
        return quantizer.fused_graph.node_copy(self.conv_or_linear_node, load_arg)

@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d))
class ModuleReLUFusion(FuseHandler):
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)
        self.relu_node = node
        assert isinstance(node.args[0], Node)
        node = node.args[0]
        assert node.op == 'call_module'
        self.module_node = node
        self.module = quantizer.modules[self.module_node.target]

    def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
             fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}
        additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
        op_list = []
        # since relu can be used multiple times, we'll need to create a relu module for each match
        if self.relu_node.op == 'call_module':
            relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
        else:
            # TODO: get inplace argument from functional
            relu = torch.nn.ReLU()
        relu.training = self.module.training
        op_list.append(relu)
        op_list.append(self.module)

        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        module_parent_name, module_name = _parent_name(self.module_node.target)
        fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
        setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
        return quantizer.fused_graph.node_copy(self.module_node, load_arg)
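For the new `(torch.nn.Linear, torch.nn.BatchNorm1d)` pattern, the fuser method returned by `get_fuser_method` folds the batch norm statistics into the linear weights at eval time. A rough sketch of that folding, assuming an affine `BatchNorm1d` with tracked running stats (the real implementation lives in the eager-mode fuser method mappings and may differ in detail):

```
import copy
import torch

def fold_linear_bn_eval(linear: torch.nn.Linear, bn: torch.nn.BatchNorm1d) -> torch.nn.Linear:
    assert not (linear.training or bn.training), "folding is only valid in eval mode"
    fused = copy.deepcopy(linear)
    # per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight = torch.nn.Parameter((linear.weight * scale[:, None]).detach())
    bias = linear.bias if linear.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias = torch.nn.Parameter(((bias - bn.running_mean) * scale + bn.bias).detach())
    return fused
```

`ConvOrLinearBNReLUFusion.fuse` then swaps the fused module in for the original linear via `setattr` and replaces the batch norm module with `torch.nn.Identity()`, as shown above.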