# mypy: ignore-errors

r"""This file contains common utility methods for checking quantized
tensors and modules.
"""
import numpy as np
import torch
from torch import Tensor
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS

supported_qengines = torch.backends.quantized.supported_engines
# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as they are flaky. Issue #29326
# QNNPACK is not supported on PPC
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]):
    supported_qengines.remove('qnnpack')


def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
                       output_padding=0):
    """Computes the output shape given convolution parameters."""
    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
                     * (dilation - 1)) / stride) + 2 * output_padding + 1


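# Example sketch: spatial output size for a 3x3 convolution. The numbers below
# are arbitrary illustration values, not taken from any particular test.
#
#   _conv_output_shape(input_size=32, kernel_size=3, padding=1, stride=1, dilation=1)  # -> 32.0
#   _conv_output_shape(input_size=32, kernel_size=3, padding=1, stride=2, dilation=1)  # -> 16.0

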
# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
    """Quantizes a numpy array."""
    if qmin is None:
        qmin = np.iinfo(dtype).min
    if qmax is None:
        qmax = np.iinfo(dtype).max
    qx = np.round(x / scale + zero_point).astype(np.int64)
    qx = np.clip(qx, qmin, qmax)
    qx = qx.astype(dtype)
    return qx


def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array."""
    x = (qx.astype(float) - zero_point) * scale
    return x


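# Example sketch: quantize/dequantize round trip. The scale and zero_point are
# arbitrary illustration values chosen so the inputs land exactly on the uint8 grid.
#
#   x = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
#   qx = _quantize(x, scale=0.01, zero_point=128)       # array([ 28, 128, 178, 228], dtype=uint8)
#   x_hat = _dequantize(qx, scale=0.01, zero_point=128)
#   np.allclose(x, x_hat)                                # True; error is bounded by scale / 2

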
def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to the given type."""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx


def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
    if qscheme == torch.per_tensor_symmetric:
        assert dtype == torch.qint8
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:  # dtype == torch.quint8
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255
    min_val = X.min()
    max_val = X.max()
    is_symmetric = (qscheme == torch.per_tensor_symmetric)
    if min_val == max_val:
        scale = 1.0
        zero_point = 0
    else:
        if is_symmetric:
            max_val = max(max_val, -min_val)
            min_val = -max_val
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = qmin - round(min_val / scale)
            zero_point = max(qmin, zero_point)
            zero_point = min(qmax, zero_point)
    return [float(scale), int(zero_point)]


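# Example sketch: dynamic per-tensor qparams for a tensor spanning [0, 1] with
# torch.quint8 (qmin=0, qmax=255). The input values are illustrative only.
#
#   scale, zero_point = _calculate_dynamic_qparams(torch.tensor([0.0, 0.5, 1.0]), torch.quint8)
#   # scale ~= 1 / 255 (range divided by the number of levels), zero_point == 0

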
def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        min_val = X.min()
        max_val = X.max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])

    return scale, zero_point


def _snr(x, x_hat):
    """Calculates the signal to noise ratio and returns the signal and noise
    power, as well as the SNR in dB.
    If the input is a list/tuple this function is called recursively on each
    element. The result will have the same nested structure as the inputs.

    Args:
        x, x_hat: Either a tensor or a nested list/tuple of tensors.
    Returns:
        signal, noise, SNR(in dB): Either floats or a nested list of floats
    """
    if isinstance(x, (list, tuple)):
        assert len(x) == len(x_hat)
        res = [_snr(x[idx], x_hat[idx]) for idx in range(len(x))]
        return res
    if x_hat.is_quantized:
        x_hat = x_hat.dequantize()
    if x.is_quantized:
        x = x.dequantize()
    noise = (x - x_hat).norm()
    if noise == 0:
        return 0.0, float('inf'), float('inf')
    signal = x.norm()
    snr = signal / noise
    snr_db = 20 * snr.log10()
    return signal, noise, snr_db


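# Example sketch: SNR of a tensor against a copy perturbed by 1% (arbitrary
# illustration values).
#
#   x = torch.ones(4)
#   signal, noise, snr_db = _snr(x, x * 1.01)
#   # signal == 2.0, noise == 0.02, snr_db == 40.0 (up to floating point error)

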
@contextmanager
def override_quantized_engine(qengine):
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = qengine
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous


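# Example sketch: temporarily switch the quantized backend for a block of test
# code, assuming 'fbgemm' is present in supported_qengines on this host.
#
#   with override_quantized_engine('fbgemm'):
#       ...  # quantized ops here run against the fbgemm backend
#   # the previously active engine is restored on exit, even on exception

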
@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
    try:
        if qengine_is_qnnpack:
            torch._C._set_default_mobile_cpu_allocator()
        yield
    finally:
        if qengine_is_qnnpack:
            torch._C._unset_default_mobile_cpu_allocator()


# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
    def test_fn(*args, **kwargs):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # qfunction should not return anything.
                qfunction(*args, **kwargs)
    return test_fn


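# Example sketch: run one test body once per available engine; the test can
# branch on the qengine_is_* helpers below for engine-specific expectations.
#
#   @override_qengines
#   def test_some_quantized_op(self):
#       if qengine_is_qnnpack():
#           ...  # qnnpack-specific tolerances
#       ...

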
def qengine_is_fbgemm():
    return torch.backends.quantized.engine == 'fbgemm'

def qengine_is_qnnpack():
    return torch.backends.quantized.engine == 'qnnpack'

def qengine_is_onednn():
    return torch.backends.quantized.engine == 'onednn'

def qengine_is_x86():
    return torch.backends.quantized.engine == 'x86'


# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list


# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    res = torch.zeros_like(X)

    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]

    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)


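# Example sketch: per-channel fake-quant along dim 0 of a 2x3 tensor, with
# arbitrary illustration values for the per-channel scales and zero points.
#
#   X = torch.randn(2, 3)
#   scale = torch.tensor([0.1, 0.2])
#   zero_point = torch.tensor([0.0, 0.0])
#   Y = _fake_quantize_per_channel_affine_reference(X, scale, zero_point, axis=0,
#                                                   quant_min=-128, quant_max=127)
#   # each row of Y is X rounded to its own grid with step scale[i], so
#   # (Y - X).abs().max() <= scale.max() / 2 for values inside the clamp range

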
# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)


def to_tensor(X, device):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X)
    else:
        X = X.detach().clone()
    return X.to(device=torch.device(device), dtype=torch.float32)


# copy-pasted from
# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
def _n_ones(n: int) -> int:
    return (1 << n) - 1


EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


# copy-pasted from
# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert FP32 numbers to sub-byte floating point numbers with the given
    number of exponent and mantissa bits.

    Input: torch.Tensor of dtype torch.float
    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding

    Note: there is no support for special values (NaN, inf) in this code. Values
    outside the representable range of Floatx after rounding are clamped to the
    maximum Floatx magnitude (sign is preserved).

    Code below is an adaptation of https://fburl.com/code/ciwofcg4

    Background 1: last answer in https://stackoverflow.com/q/8981913
    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
    """
    assert x.dtype == torch.float
    assert 1 + ebits + mbits <= 8

    # calculate constants
    exp_bias = _n_ones(ebits - 1)
    max_int = _n_ones(ebits + mbits)
    sign_mask = 1 << (ebits + mbits)

    # TODO document this better
    magic_adder = _n_ones(MBITS_F32 - mbits - 1)

    # all E bits and M bits are 1s
    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))

    # E bits = 1, M bits = 0
    min_normal = 2 ** (1 - exp_bias)

    denorm_exp = (
        # exp bias conversion between formats
        (F32_EXP_BIAS - exp_bias)
        # mantissa length difference between formats
        + (MBITS_F32 - mbits)
        # add one to encoded exponent for denormalized numbers
        + 1
    )
    denorm_mask_int = denorm_exp << MBITS_F32

    # reinterpret int32 as float32
    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
        torch.float32
    )

    # save the sign
    # Note that we have torch.uint32, but some ops like cpu bit shifts
    # do not work on it. So, we stay in int32.
    x = x.view(torch.int32)
    sign = x & 0x80000000

    # set everything to positive, will add sign back at the end
    x = x ^ sign

    # TODO: can the branch floating point comparisons below be done without
    # converting to float? probably but need to verify
    x = x.view(torch.float)

    # rewrite saturate/denorm/norm branches without explicit data dependent
    # control flow, to be more compiler friendly
    saturate_mask = x >= max_normal
    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

    #
    # branch 1: saturate to max val - handled later in the code which combines
    #   the branches
    #

    #
    # branch 2: handle conversion to denormal as well as rounding up to normal
    #
    denormal_x = x + denorm_mask_float
    denormal_x = denormal_x.view(torch.int32)
    denormal_x -= denorm_mask_int
    denormal_x = denormal_x.to(torch.uint8)

    #
    # branch 3: stay in normal range, adjust the exponent and round
    #
    normal_x = x.view(torch.int32)
    # resulting mantissa is odd
    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
    # update exponent, rounding bias part 1
    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
    normal_x += val_to_add
    # rounding bias part 2
    normal_x += mant_odd
    # take the bits!
    normal_x = normal_x >> (MBITS_F32 - mbits)
    normal_x = normal_x.to(torch.uint8)

    #
    # combine the branches
    #
    x = torch.full_like(x, max_int, dtype=torch.uint8)
    x = torch.where(denormal_mask, denormal_x, x)
    x = torch.where(normal_mask, normal_x, x)

    # add sign back
    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
    sign_lp = sign_lp.to(torch.uint8)
    # Right shift of a negative signed integer can fill the least significant
    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
    # doesn't have an uint32 dtype, we mask out these bits to get just the
    # f4 sign bit
    sign_lp = sign_lp & sign_mask
    x = x | sign_lp

    return x.to(torch.uint8)


# copy-pasted from
# https://github.com/pytorch/ao/blob/29488018d99af7f7339f06353c6b5bbeae8a1493/torchao/prototype/custom_fp_utils.py#L147
def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert sub-byte floating point numbers with the given number of exponent
    and mantissa bits to FP32.

    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
    Output: torch.Tensor of dtype fp32 with the dequantized value
    """
    assert x.dtype == torch.uint8
    assert 1 + ebits + mbits <= 8

    sign_mask = 1 << (ebits + mbits)
    exp_bias = _n_ones(ebits - 1)
    mantissa_mask = _n_ones(mbits)

    # save the sign
    sign_lp = x & sign_mask

    # set everything to positive, will add sign back at the end
    x_pos = x ^ sign_lp

    #
    # 1. Calculate zero mask
    #
    zero_mask = x_pos == 0

    #
    # 2. Calculate the denormal path mask
    #
    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))

    #
    # 3. Calculate the normal path
    #

    # calculate the new exponent and shift it to bits 2:9 of the result
    exp_biased_lp = x_pos >> mbits
    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32

    # shift the mantissa to bits 10:32 of the result
    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
    result = exp_biased_f32 | mantissa_f32

    #
    # 4. Add the zero and denormal casts to the already casted normal path
    #
    result[zero_mask] = 0

    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

    # fast path.
    # without this, performance for FP4_E2M1 is slower by 2x
    if mbits == 1:
        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

    else:
        # iterate over all possible values of mantissa
        # i=0, j=1
        # i=1, j=10,11
        # i=2, j=100,101,110,111
        # and so on
        for i in range(mbits):
            for mantissa_cmp in range(1 << i, 1 << (i + 1)):
                # left shift mantissa until it overflows (create an implicit 1)
                # subtract exponent by the same amount
                left_shift = mbits - i
                mantissa_f32 = (mantissa_cmp - (1 << i)) << (
                    left_shift + MBITS_F32 - mbits
                )
                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

                # we can update this in-place since the values won't overlap
                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
                # thus we use + instead of | here
                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = (
                    exp_biased_f32 + mantissa_f32
                )

        result = torch.where(denormal_mask, mantissa_lp_int32, result)

    # add sign back
    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
    result = result | sign_f32

    return result.view(torch.float)


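# Example sketch: fp4_e2m1 (ebits=2, mbits=1) round trip. Values that are
# exactly representable in fp4_e2m1 should survive encode/decode unchanged.
#
#   vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -6.0])
#   packed = _f32_to_floatx_unpacked(vals, ebits=2, mbits=1)       # uint8, low 4 bits used
#   roundtrip = _floatx_unpacked_to_f32(packed, ebits=2, mbits=1)
#   torch.equal(roundtrip, vals)                                    # expected True

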
# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
def ceil_div(a, b):
    return (a + b - 1) // b

def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    # Ideally we would use torch.nn.functional.pad but it doesn't support float8_e8m0fnu for now
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
        padded[:rows, :cols] = input_matrix

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
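

# Example sketch: one full 128x4 tile of scales is rearranged into the blocked
# layout (32x16 per tile) and flattened. The input values are arbitrary.
#
#   scales = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4)
#   blocked = to_blocked(scales)
#   # blocked.shape == torch.Size([512])   # 32 * 16 entries for the single tile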