mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
take2: Docstring only changes in quantization, fake_quantize, and observer (#27574)
* docstring only formatting changes in the quantize.py and fake_quantization.py files to render better in HTML. * docstring change on observer.py as well * just kind of tweaking the docstrings a bit more. * switching to r""" for the multi-line string. Per Zafar's suggestion. * trying to resolve the merge conflict soumith saw * trying to avoid a conflict when this gets merged back to master
This commit is contained in:
committed by
Soumith Chintala
parent
fb489555a9
commit
f0d3fc70b4
@ -4,13 +4,11 @@ from torch.nn import Module
|
||||
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
|
||||
|
||||
class FakeQuantize(Module):
|
||||
''' Simulate the quantize and dequantize operations in training time.
|
||||
r""" Simulate the quantize and dequantize operations in training time.
|
||||
The output of this module is given by
|
||||
|
||||
x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
|
||||
|
||||
|
||||
|
||||
* :attr:`scale` defines the scale factor used for quantization.
|
||||
|
||||
* :attr:`zero_point` specifies the quantized value to which 0 in floating point maps
|
||||
@ -25,13 +23,13 @@ class FakeQuantize(Module):
|
||||
* :attr:`observer_enable` controls statistics collection on tensors
|
||||
|
||||
* :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
|
||||
allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max should
|
||||
be chosen to be consistent with the dtype
|
||||
allowable values are torch.qint8 and torch.quint8. The values of quant_min and
|
||||
quant_max should be chosen to be consistent with the dtype
|
||||
|
||||
|
||||
Args:
|
||||
observer (module): Module for observing statistics on input tensors and calculating
|
||||
scale and zero-point.
|
||||
observer (module): Module for observing statistics on input tensors and calculating scale
|
||||
and zero-point.
|
||||
quant_min (int): The minimum allowable quantized value.
|
||||
quant_max (int): The maximum allowable quantized value.
|
||||
observer_kwargs (optional): Arguments for the observer module
|
||||
@ -39,15 +37,7 @@ class FakeQuantize(Module):
|
||||
Attributes:
|
||||
observer (Module): User provided module that collects statistics on the input tensor and
|
||||
provides a method to calculate scale and zero-point.
|
||||
|
||||
"""
|
||||
Args:
|
||||
`observer`: Observer module that records stats of input tensor
|
||||
`quant_min`: Tensors are fake-quantized corresponding to the
|
||||
`quant_max`: A function that calculates quantization parameters
|
||||
given the stats
|
||||
`observer_kwargs`
|
||||
'''
|
||||
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
|
||||
super(FakeQuantize, self).__init__()
|
||||
assert quant_min <= quant_max, \
|
||||
|
||||
@ -80,12 +80,10 @@ def add_observer_(module):
|
||||
has a valid qconfig attribute.
|
||||
|
||||
Args:
|
||||
module: input module with qconfig attributes for all the leaf modules
|
||||
that we want to quantize
|
||||
module: input module with qconfig attributes for all the leaf modules that we want to quantize
|
||||
|
||||
Return:
|
||||
None, module is modified inplace with added observer modules and
|
||||
forward_hooks
|
||||
None, module is modified inplace with added observer modules and forward_hooks
|
||||
"""
|
||||
for child in module.children():
|
||||
if type(child) == nnq.FloatFunctional:
|
||||
@ -252,7 +250,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
|
||||
Args:
|
||||
model: input model
|
||||
run_fn: a function for evaluating the prepared model, can be a
|
||||
function that simply runs the prepared model or a training loop
|
||||
function that simply runs the prepared model or a training
|
||||
loop
|
||||
run_args: positional arguments for `run_fn`
|
||||
|
||||
Return:
|
||||
@ -269,10 +268,11 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
|
||||
def convert(module, mapping=None, inplace=False):
|
||||
r"""Converts the float module with observers (where we can get quantization
|
||||
parameters) to a quantized module.
|
||||
|
||||
Args:
|
||||
module: calibrated module with observers
|
||||
mapping: a dictionary that maps from float module type to quantized
|
||||
module type, can be overwritten to allow swapping user defined Modules
|
||||
mapping: a dictionary that maps from float module type to quantized module type, can
|
||||
be overwritten to allow swapping user defined Modules
|
||||
inplace: carry out model transformations in-place, the original module is mutated
|
||||
"""
|
||||
if mapping is None:
|
||||
|
||||
Reference in New Issue
Block a user