mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 04:04:57 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59353 Next: remove Quantizer class Test Plan: Imported from OSS Reviewed By: raghuramank100 Differential Revision: D28856277 fbshipit-source-id: 25f5502be387dbe9706780f667501b46b82789a5
		
			
				
	
	
		
			60 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			60 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch.fx.graph import (
    Node,
)

from .quantization_types import Pattern
from .qconfig_utils import QConfigAny
# from .quantization_patterns import BinaryOpQuantizeHandler
 | |
| 
 | |
| 
 | |
# TODO(future PR): type QuantizeHandler precisely. Referencing the real handler
# class from here would create a circular dependency, so it is aliased to Any.
QuantizeHandler = Any

# Type of a single pattern-match result. The meaning of each position is not
# defined in this module -- presumably (root node, matched nodes, matched
# pattern, quantize handler, qconfig); NOTE(review): confirm against the matcher.
MatchResult = Tuple[
    Node,
    List[Node],
    Optional[Pattern],
    QuantizeHandler,
    QConfigAny,
]
 | |
| 
 | |
| # pattern for conv bn fusion
 | |
| DEFAULT_FUSION_PATTERNS = OrderedDict()
 | |
| def register_fusion_pattern(pattern):
 | |
|     def insert(fn):
 | |
|         DEFAULT_FUSION_PATTERNS[pattern] = fn
 | |
|         return fn
 | |
|     return insert
 | |
| 
 | |
def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    """Return the fusion-pattern registry.

    Values are whatever was registered with ``register_fusion_pattern``. The
    ``QuantizeHandler`` annotation is only an alias for ``Any``. The module-level
    ``DEFAULT_FUSION_PATTERNS`` object itself is returned (not a copy), so any
    mutation by the caller is visible to later lookups.
    """
    registry = DEFAULT_FUSION_PATTERNS
    return registry
 | |
| 
 | |
# Registry mapping a quantization pattern to its quantize handler, shared by
# static quantization and QAT. Insertion order is preserved.
DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()

# Maps a pattern to the activation_post_process (observer/fake_quant) constructor
# used for that pattern's output activation, e.g.
#   pattern: torch.sigmoid
#   output_activation_post_process: default_affine_fixed_qparam_fake_quant
DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP = {}


# Register pattern for both static quantization and qat
def register_quant_pattern(pattern, output_activation_post_process=None):
    """Return a decorator that registers its argument as the quantize handler for ``pattern``.

    Args:
        pattern: key under which the handler is stored in
            ``DEFAULT_QUANTIZATION_PATTERNS``. Any previous handler for the
            same key is replaced.
        output_activation_post_process: optional observer/fake_quant
            constructor for the pattern's output. It is recorded in
            ``DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP`` only when not None.

    The decorated object is returned unchanged.

    NOTE(review): re-registering a pattern without an
    ``output_activation_post_process`` leaves any entry previously stored for
    that pattern in ``DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP`` untouched.
    """
    def _register(handler):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = handler
        if output_activation_post_process is None:
            return handler
        DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP[pattern] = output_activation_post_process
        return handler

    return _register
 | |
| 
 | |
# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    """Return the quantization-pattern registry shared by static quantization and QAT.

    Maps each pattern registered through ``register_quant_pattern`` to its
    handler. The module-level ``DEFAULT_QUANTIZATION_PATTERNS`` object itself is
    returned (not a copy), so any mutation by the caller is visible to later
    lookups.
    """
    registry = DEFAULT_QUANTIZATION_PATTERNS
    return registry
 | |
| 
 | |
def get_default_output_activation_post_process_map() -> Dict[Pattern, Callable[..., Any]]:
    """Return the map from pattern to output activation_post_process constructor.

    Example entry: ``torch.sigmoid -> default_affine_fixed_qparam_fake_quant``.

    Values are the ``output_activation_post_process`` arguments passed to
    ``register_quant_pattern``. These are observer/fake_quant *constructors*
    (callables), not observer instances. The return annotation is therefore
    ``Callable[..., Any]``; the previous ``ObserverBase`` annotation described
    the wrong kind of object. The module-level
    ``DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP`` itself is returned, not a copy.
    """
    return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP
 | |
| 
 | |
| 
 | |
| # Example use of register pattern function:
 | |
| # @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
 | |
| # class ConvBNReLUFusion():
 | |
| #     def __init__(...):
 | |
| #         ...
 | |
| #
 |