mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	take2: Docstring only changes in quantization, fake_quantize, and observer (#27574)
* docstring only formatting changes in the quantize.py and fake_quantization.py files to render better in HTML. * docstring change on observer.py as well * just kind of tweaking the docstrings a bit more. * switching to r""" for the multi-line string. Per Zafar's suggestion. * trying to resolve the merge conflict soumith saw * trying to avoid a conflict when this gets merged back to master
This commit is contained in:
		
				
					committed by
					
						
						Soumith Chintala
					
				
			
			
				
	
			
			
			
						parent
						
							fb489555a9
						
					
				
				
					commit
					f0d3fc70b4
				
			@ -4,13 +4,11 @@ from torch.nn import Module
 | 
			
		||||
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
 | 
			
		||||
 | 
			
		||||
class FakeQuantize(Module):
 | 
			
		||||
    ''' Simulate the quantize and dequantize operations in training time.
 | 
			
		||||
    r""" Simulate the quantize and dequantize operations in training time.
 | 
			
		||||
    The output of this module is given by
 | 
			
		||||
 | 
			
		||||
    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    * :attr:`scale` defines the scale factor used for quantization.
 | 
			
		||||
 | 
			
		||||
    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps
 | 
			
		||||
@ -25,13 +23,13 @@ class FakeQuantize(Module):
 | 
			
		||||
    * :attr:`observer_enable` controls statistics collection on tensors
 | 
			
		||||
 | 
			
		||||
    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
 | 
			
		||||
     allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max should
 | 
			
		||||
     be chosen to be consistent with the dtype
 | 
			
		||||
                    allowable values are torch.qint8 and torch.quint8. The values of quant_min and 
 | 
			
		||||
                    quant_max should be chosen to be consistent with the dtype
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        observer (module): Module for observing statistics on input tensors and calculating
 | 
			
		||||
        scale and zero-point.
 | 
			
		||||
        observer (module): Module for observing statistics on input tensors and calculating scale 
 | 
			
		||||
                           and zero-point.
 | 
			
		||||
        quant_min (int): The minimum allowable quantized value.
 | 
			
		||||
        quant_max (int): The maximum allowable quantized value.
 | 
			
		||||
        observer_kwargs (optional): Arguments for the observer module
 | 
			
		||||
@ -39,15 +37,7 @@ class FakeQuantize(Module):
 | 
			
		||||
    Attributes:
 | 
			
		||||
        observer (Module): User provided module that collects statistics on the input tensor and
 | 
			
		||||
                           provides a method to calculate scale and zero-point.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    Args:
 | 
			
		||||
        `observer`: Observer module that records stats of input tensor
 | 
			
		||||
        `quant_min`: Tensors are fake-quantized corresponding to the
 | 
			
		||||
        `quant_max`: A function that calculates quantization parameters
 | 
			
		||||
        given the stats
 | 
			
		||||
        `observer_kwargs`
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
 | 
			
		||||
        super(FakeQuantize, self).__init__()
 | 
			
		||||
        assert quant_min <= quant_max, \
 | 
			
		||||
 | 
			
		||||
@ -174,7 +174,7 @@ class MinMaxObserver(_ObserverBase):
 | 
			
		||||
    r"""Default Observer Module
 | 
			
		||||
    A default implementation of the observer module, only works for
 | 
			
		||||
    `per_tensor_affine` quantization scheme.  The module will record the
 | 
			
		||||
     running average of max and min value of the observed Tensor and
 | 
			
		||||
    running average of max and min value of the observed Tensor and
 | 
			
		||||
    calculate_qparams will calculate scale and zero_point
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -80,12 +80,10 @@ def add_observer_(module):
 | 
			
		||||
    has a valid qconfig attribute.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        module: input module with qconfig attributes for all the leaf modules
 | 
			
		||||
        that we want to quantize
 | 
			
		||||
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
 | 
			
		||||
 | 
			
		||||
    Return:
 | 
			
		||||
        None, module is modified inplace with added observer modules and
 | 
			
		||||
            forward_hooks
 | 
			
		||||
        None, module is modified inplace with added observer modules and forward_hooks
 | 
			
		||||
    """
 | 
			
		||||
    for child in module.children():
 | 
			
		||||
        if type(child) == nnq.FloatFunctional:
 | 
			
		||||
@ -109,7 +107,7 @@ def add_quant_dequant(module):
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        module: input module with qconfig attributes for all the leaf modules
 | 
			
		||||
        that we want to quantize
 | 
			
		||||
                that we want to quantize
 | 
			
		||||
 | 
			
		||||
    Return:
 | 
			
		||||
        Either the inplace modified module with submodules wrapped in
 | 
			
		||||
@ -252,7 +250,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
 | 
			
		||||
    Args:
 | 
			
		||||
        model: input model
 | 
			
		||||
        run_fn: a function for evaluating the prepared model, can be a
 | 
			
		||||
            function that simply runs the prepared model or a training loop
 | 
			
		||||
                function that simply runs the prepared model or a training 
 | 
			
		||||
                loop
 | 
			
		||||
        run_args: positional arguments for `run_fn`
 | 
			
		||||
 | 
			
		||||
    Return:
 | 
			
		||||
@ -269,10 +268,11 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
 | 
			
		||||
def convert(module, mapping=None, inplace=False):
 | 
			
		||||
    r"""Converts the float module with observers (where we can get quantization
 | 
			
		||||
    parameters) to a quantized module.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        module: calibrated module with observers
 | 
			
		||||
        mapping: a dictionary that maps from float module type to quantized
 | 
			
		||||
           module type, can be overwritten to allow swapping user defined Modules
 | 
			
		||||
        mapping: a dictionary that maps from float module type to quantized module type, can 
 | 
			
		||||
                 be overwritten to allow swapping user defined Modules
 | 
			
		||||
        inplace: carry out model transformations in-place, the original module is mutated
 | 
			
		||||
    """
 | 
			
		||||
    if mapping is None:
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user