from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from torch.nn import Module
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args


class FakeQuantize(Module):
    r""" Simulate the quantize and dequantize operations in training time.

    The output of this module is given by

    x_out = (clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point) * scale
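
    For example, with scale = 0.1, zero_point = 0, quant_min = 0 and quant_max = 255
    (illustrative values), an input of 1.2345 is fake-quantized to 1.2:

        clamp(round(1.2345 / 0.1 + 0), 0, 255) = 12
        (12 - 0) * 0.1 = 1.2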

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps.

    * :attr:`quant_min` specifies the minimum allowable quantized value.

    * :attr:`quant_max` specifies the maximum allowable quantized value.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors.

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization;
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and
      quant_max should be chosen to be consistent with the dtype.

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
            and zero-point.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        observer (Module): User provided module that collects statistics on the input tensor and
            provides a method to calculate scale and zero-point.
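
    Example (a minimal usage sketch using this module's defaults)::

        >>> fq = FakeQuantize()       # default observer, quantization range [0, 255]
        >>> x = torch.randn(4, 8)
        >>> y = fq(x)                 # update observer statistics, then fake-quantize x
        >>> fq.disable_observer()     # freeze the current scale / zero_point
        >>> fq.disable_fake_quant()   # pass inputs through unchanged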
"""
|
|
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
|
|
super(FakeQuantize, self).__init__()
|
|
assert quant_min <= quant_max, \
|
|
'quant_min must be less than or equal to quant_max'
|
|
self.quant_min = quant_min
|
|
self.quant_max = quant_max
|
|
self.fake_quant_enabled = True
|
|
self.observer_enabled = True
|
|
self.activation_post_process = observer(**observer_kwargs)
|
|
assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
|
|
assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
|
|
self.register_buffer('scale', torch.tensor([1.0]))
|
|
self.register_buffer('zero_point', torch.tensor([0]))
|
|
self.dtype = self.activation_post_process.dtype
|
|
self.qscheme = self.activation_post_process.qscheme
|
|
self.ch_axis = self.activation_post_process.ch_axis if hasattr(self.activation_post_process, 'ch_axis') else None
|
|
|
|
def enable_fake_quant(self, enabled=True):
|
|
self.fake_quant_enabled = enabled
|
|
return self
|
|
|
|
def disable_fake_quant(self):
|
|
return self.enable_fake_quant(False)
|
|
|
|
def enable_observer(self, enabled=True):
|
|
self.observer_enabled = enabled
|
|
return self
|
|
|
|
def disable_observer(self):
|
|
return self.enable_observer(False)
|
|
|
|
def calculate_qparams(self):
|
|
return self.activation_post_process.calculate_qparams()
|
|
|
|

    def forward(self, X):
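        # If observation is enabled, update the observer's running statistics with the
        # (detached) input and refresh scale / zero_point from its current estimates.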
        if self.observer_enabled:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            self.scale, self.zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
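        # If fake quantization is enabled, simulate quantization: round and clamp X onto the
        # quantized grid, then dequantize it back to floating point.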
        if self.fake_quant_enabled:
            if self.qscheme == torch.per_channel_symmetric or self.qscheme == torch.per_channel_affine:
                X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point,
                                                           self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(X, float(self.scale),
                                                          int(self.zero_point), self.quant_min,
                                                          self.quant_max)
        return X
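
    # `with_args` (from .observer) builds a factory with pre-bound constructor arguments:
    # e.g. FakeQuantize.with_args(quant_min=-128, quant_max=127) returns a callable that
    # constructs FakeQuantize instances with those arguments (used for the defaults below).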
    with_args = classmethod(_with_args)

    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}'.format(
            self.fake_quant_enabled, self.observer_enabled, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so we need to manually
        # specify serialization here.
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Removing this function throws an error that the size of the loaded tensor does not match
        # the original size, i.e. these buffers start out with numel 0 and become numel 1 once they
        # have their first forward pass.
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                        missing_keys, unexpected_keys, error_msgs)
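

# Pre-configured FakeQuantize factories created via `with_args`; calling one, e.g.
# `default_fake_quant()`, constructs a FakeQuantize module with the arguments bound below.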
default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)

default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128,
                                                               quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)
default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      dtype=torch.quint8,
                                                      qscheme=torch.per_tensor_affine,
                                                      reduce_range=True)
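

# Module-level toggles. Each helper acts on a single module and only touches FakeQuantize
# instances, so it matches the callback signature of `torch.nn.Module.apply`; for example,
# `model.apply(disable_fake_quant)` turns off fake quantization for every FakeQuantize
# submodule in `model`.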
def disable_fake_quant(mod):
    if type(mod) == FakeQuantize:
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    if type(mod) == FakeQuantize:
        mod.enable_fake_quant()


def disable_observer(mod):
    if type(mod) == FakeQuantize:
        mod.disable_observer()


def enable_observer(mod):
    if type(mod) == FakeQuantize:
        mod.enable_observer()