Change Weight to QTensor with qint8 (int8_t) (#20712)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20712

As the title says: change the FC weight from a quint8 QTensor (zero point offset by 128) to a qint8 (int8_t) QTensor, so the packing kernel can read the int8 data directly.

Differential Revision: D15410696

fbshipit-source-id: 48147a79d8cc47a724eb473796a37a1c64f8e883
Author: Jianyu Huang
Date: 2019-05-21 12:32:59 -07:00
Committer: Facebook Github Bot
Parent: ac2314fdeb
Commit: b9a150ede0

2 changed files with 5 additions and 10 deletions

@@ -45,16 +45,13 @@ class QFCPackWeightInt8 final : public c10::OperatorKernel {
   auto N = weight.size(0);
   auto K = weight.size(1);
-  int32_t weight_zero_point_int32 = weight.q_zero_point().toInt() - 128;
+  int32_t weight_zero_point_int32 = weight.q_zero_point().toInt();
   // TODO: contiguous is called for further JIT optimizations.
   auto weight_contig = weight.contiguous();
-  std::vector<int8_t> weight_int8(K * N);
-  int8_t* weight_ptr_int8 = weight_int8.data();
-  uint8_t* weight_ptr_uint8 =
-      reinterpret_cast<uint8_t*>(weight_contig.data<c10::quint8>());
-  convert_uint8_int8(K, N, weight_ptr_uint8, weight_ptr_int8);
+  int8_t* weight_ptr_int8 =
+      reinterpret_cast<int8_t*>(weight_contig.data<c10::qint8>());
   std::vector<int32_t> col_offsets(N);
   calc_col_offsets_transpose(
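The deleted conversion pass is unnecessary once the weight arrives as qint8: a quint8 tensor with zero point (zp + 128) and a qint8 tensor with zero point zp store the same real values, which is also why the "- 128" adjustment above disappears. A minimal numpy sketch of that identity, assuming convert_uint8_int8 shifted each byte by 128 (the values here are illustrative, not from the PR):

import numpy as np

scale, zp = 0.4, 2                                  # mirrors W_scale / W_zp in the tests below
stored_u8 = np.array([0, 1, 130, 255], dtype=np.uint8)          # quint8 storage, zero point zp + 128
stored_i8 = (stored_u8.astype(np.int16) - 128).astype(np.int8)  # qint8 storage, zero point zp

# Dequantize both representations: real = scale * (stored - zero_point).
real_u8 = scale * (stored_u8.astype(np.float32) - (zp + 128))
real_i8 = scale * (stored_i8.astype(np.float32) - zp)
assert np.allclose(real_u8, real_i8)                # identical dequantized weights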

@@ -218,7 +218,6 @@ class TestQuantizedFC(unittest.TestCase):
         ).astype(np.uint8)
         W_scale = 0.4
-        # W_zp is the zero point for int8 quantization.
         W_zp = 2
         W_value_min = -128
         W_value_max = 127
@@ -244,8 +243,7 @@ class TestQuantizedFC(unittest.TestCase):
         W = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)
         X_q = X.quantize_linear(scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
-        # W_zp + 128 is the zero point for uint8 quantization.
-        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp + 128, dtype=torch.quint8)
+        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp, dtype=torch.qint8)
         b_q = torch.round(torch.rand(output_channels) * 10 - 10).to(dtype=torch.int32)
         # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
@@ -322,7 +320,7 @@ class TestQuantizedFC(unittest.TestCase):
         W = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)
         X_q = X.quantize_linear(scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
-        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp + 128, dtype=torch.quint8)
+        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp, dtype=torch.qint8)
         b_q = torch.round(torch.rand(output_channels) * 10 - 10).to(dtype=torch.int32)
         # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
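The _dequantize helper is defined elsewhere in the test file and not shown in this diff; assuming it implements standard affine dequantization, re-quantizing W with the same scale and zero point (what W.quantize_linear(..., dtype=torch.qint8) does above) reproduces W_q0 exactly. A self-contained numpy sketch of that round trip, with the helper's assumed definition:

import numpy as np

def _dequantize(q, scale, zero_point):
    # Assumed form of the test helper (not shown in this diff):
    # real = scale * (q - zero_point).
    return scale * (q.astype(np.float32) - zero_point)

W_scale, W_zp = 0.4, 2
W_q0 = np.random.randint(-128, 128, size=(4, 3)).astype(np.int8)
W = _dequantize(W_q0, W_scale, W_zp)

# Re-quantize into the int8 range and check the stored values survive.
W_requant = np.clip(np.round(W / W_scale) + W_zp, -128, 127).astype(np.int8)
assert np.array_equal(W_requant, W_q0)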