Change Weight to QTensor with qint8(int8_t) (#20712)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20712

As Title says.

Differential Revision: D15410696

fbshipit-source-id: 48147a79d8cc47a724eb473796a37a1c64f8e883
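The change is representational only: an affine-quantized tensor dequantizes as r = scale * (q - zero_point), so storing the weight as qint8 with zero point W_zp encodes exactly the same real values as storing it as quint8 with zero point W_zp + 128, since every stored value and the zero point shift by 128 together. A minimal numpy sketch of that equivalence (the dequantize helper here is illustrative, not code from the PR):

import numpy as np

def dequantize(q, scale, zero_point):
    # Affine dequantization: r = scale * (q - zero_point)
    return scale * (q.astype(np.float32) - zero_point)

scale, zp_int8 = 0.4, 2          # the constants used in the updated test
q_int8 = np.array([-128, -1, 2, 127], dtype=np.int8)

# The old uint8 representation shifts both the stored values and the
# zero point up by 128; the encoded real numbers are unchanged.
q_uint8 = (q_int8.astype(np.int16) + 128).astype(np.uint8)
zp_uint8 = zp_int8 + 128

assert np.allclose(dequantize(q_int8, scale, zp_int8),
                   dequantize(q_uint8, scale, zp_uint8))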
committed by Facebook Github Bot
parent ac2314fdeb
commit b9a150ede0
@@ -45,16 +45,13 @@ class QFCPackWeightInt8 final : public c10::OperatorKernel {
     auto N = weight.size(0);
     auto K = weight.size(1);

-    int32_t weight_zero_point_int32 = weight.q_zero_point().toInt() - 128;
+    int32_t weight_zero_point_int32 = weight.q_zero_point().toInt();

     // TODO: contiguous is called for further JIT optimizations.
     auto weight_contig = weight.contiguous();

-    std::vector<int8_t> weight_int8(K * N);
-    int8_t* weight_ptr_int8 = weight_int8.data();
-    uint8_t* weight_ptr_uint8 =
-        reinterpret_cast<uint8_t*>(weight_contig.data<c10::quint8>());
-    convert_uint8_int8(K, N, weight_ptr_uint8, weight_ptr_int8);
+    int8_t* weight_ptr_int8 =
+        reinterpret_cast<int8_t*>(weight_contig.data<c10::qint8>());

     std::vector<int32_t> col_offsets(N);
     calc_col_offsets_transpose(
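With the weight tensor now holding c10::qint8 storage, the packing kernel reads int8 data in place instead of materializing a shifted copy, and the -128 correction to the zero point goes away for the same reason. The removed convert_uint8_int8 helper presumably performed that shift elementwise; a sketch of the assumed behavior (the real helper takes K, N and raw pointers, simplified here to an array):

import numpy as np

def convert_uint8_int8(weight_uint8):
    # Assumed behavior of the removed C++ helper: shift uint8
    # storage [0, 255] down by 128 into int8 storage [-128, 127].
    return (weight_uint8.astype(np.int16) - 128).astype(np.int8)

w = np.array([[0, 128, 255]], dtype=np.uint8)
print(convert_uint8_int8(w))  # [[-128    0  127]]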
@@ -218,7 +218,6 @@ class TestQuantizedFC(unittest.TestCase):
         ).astype(np.uint8)

         W_scale = 0.4
-        # W_zp is the zero point for int8 quantization.
         W_zp = 2
         W_value_min = -128
         W_value_max = 127
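For reading the test hunks below: W_q0 is an integer weight drawn over the full int8 range [W_value_min, W_value_max], and _dequantize (defined elsewhere in the test file; its body is assumed here) maps it back to floats with the usual affine formula, so W is exactly representable at (W_scale, W_zp). A sketch under those assumptions:

import numpy as np

def _dequantize(q, scale, zero_point):
    # Assumed test helper: r = scale * (q - zero_point)
    return scale * (q.astype(np.float32) - zero_point)

W_scale, W_zp = 0.4, 2
W_q0 = np.random.randint(-128, 128, size=(4, 8))  # full int8 range
W = _dequantize(W_q0, W_scale, W_zp)              # exactly representable floats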
@@ -244,8 +243,7 @@
         W = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)

         X_q = X.quantize_linear(scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
-        # W_zp + 128 is the zero point for uint8 quantization.
-        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp + 128, dtype=torch.quint8)
+        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp, dtype=torch.qint8)
         b_q = torch.round(torch.rand(output_channels) * 10 - 10).to(dtype=torch.int32)

         # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
@@ -322,7 +320,7 @@ class TestQuantizedFC(unittest.TestCase):
         W = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)

         X_q = X.quantize_linear(scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
-        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp + 128, dtype=torch.quint8)
+        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp, dtype=torch.qint8)
         b_q = torch.round(torch.rand(output_channels) * 10 - 10).to(dtype=torch.int32)

         # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
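Both test hunks make the same substitution: W keeps its integer zero point W_zp but is quantized directly to qint8 instead of to quint8 with the zero point shifted by 128. quantize_linear is the factory name from this era of PyTorch; it was later renamed quantize_per_tensor. Using the current name, a minimal check that the two representations round-trip to the same floats:

import torch

W = torch.randn(4, 8)
scale, zp = 0.4, 2

# Old representation: uint8 storage, zero point shifted up by 128.
w_quint8 = torch.quantize_per_tensor(W, scale, zp + 128, torch.quint8)
# New representation: int8 storage, zero point used as-is.
w_qint8 = torch.quantize_per_tensor(W, scale, zp, torch.qint8)

# Both encode the same real values, so they dequantize identically.
assert torch.allclose(w_quint8.dequantize(), w_qint8.dequantize())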