Add dequantize_linear for JIT pass (#20107)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20107

As titled: adds a `dequantize_linear` operator so the JIT quantization pass can dequantize tensors from an explicit scale and zero point.

Reviewed By: nishantpdce

Differential Revision: D15202187

fbshipit-source-id: 7d6274a67fcca695c0425587f35046fecbc2ccdc
Jerry Zhang authored on 2019-05-21 12:15:44 -07:00; committed by Facebook Github Bot
parent cc02a1af61, commit cca923c481
10 changed files with 86 additions and 7 deletions


@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/core/Tensor.h>
+#include <c10/macros/Macros.h>
 #include <c10/util/Half.h>
 #include <c10/util/Exception.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
@@ -11,6 +12,13 @@
     return __VA_ARGS__();                                                 \
   }
 
+#define AT_QINT_PRIVATE_CASE_TYPE(enum_type, type, underlying_type, ...)  \
+  case enum_type: {                                                       \
+    using scalar_t C10_UNUSED = type;                                     \
+    using underlying_t C10_UNUSED = underlying_type;                      \
+    return __VA_ARGS__();                                                 \
+  }
+
 namespace detail {
 
 template <at::ScalarType N>
@@ -211,14 +219,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
 #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...)                         \
   [&] {                                                                 \
     switch (TYPE) {                                                     \
-      AT_PRIVATE_CASE_TYPE(                                             \
-          at::ScalarType::QInt8, qint8, __VA_ARGS__)                    \
-      AT_PRIVATE_CASE_TYPE(                                             \
-          at::ScalarType::QUInt8, quint8, __VA_ARGS__)                  \
-      AT_PRIVATE_CASE_TYPE(                                             \
-          at::ScalarType::QInt32, qint32, __VA_ARGS__)                  \
+      AT_QINT_PRIVATE_CASE_TYPE(                                        \
+          at::ScalarType::QInt8, qint8, int8_t, __VA_ARGS__)            \
+      AT_QINT_PRIVATE_CASE_TYPE(                                        \
+          at::ScalarType::QUInt8, quint8, uint8_t, __VA_ARGS__)         \
+      AT_QINT_PRIVATE_CASE_TYPE(                                        \
+          at::ScalarType::QInt32, qint32, int, __VA_ARGS__)             \
       default:                                                          \
-        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     }                                                                   \
   }()
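For context on the macro change: AT_QINT_PRIVATE_CASE_TYPE differs from AT_PRIVATE_CASE_TYPE by additionally aliasing underlying_t, the plain integer type backing each quantized type, so kernels can read raw storage without hardcoding it. A minimal sketch of the pattern, following the same setup as dequantize_linear_cpu later in this diff (sum_raw_values is a hypothetical function, not part of this commit):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical kernel: sums the raw values of a plain CPU integer tensor,
// dispatching on the quantized type that corresponds to its dtype.
int64_t sum_raw_values(const at::Tensor& t) {
  int64_t acc = 0;
  AT_DISPATCH_QINT_TYPES(
      at::toQIntType(t.scalar_type()), "sum_raw_values", [&]() {
        // scalar_t is e.g. at::quint8; underlying_t is its storage type,
        // e.g. uint8_t -- the alias introduced by AT_QINT_PRIVATE_CASE_TYPE.
        const underlying_t* data = t.data<underlying_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          acc += static_cast<int64_t>(data[i]);
        }
      });
  return acc;
}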


@@ -583,6 +583,7 @@ class CAFFE2_API Tensor {
   Tensor to_mkldnn() const;
   Tensor quantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
   Tensor dequantize() const;
+  Tensor dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
   Scalar q_scale() const;
   Scalar q_zero_point() const;
   Tensor int_repr() const;


@@ -804,6 +804,9 @@ inline Tensor Tensor::quantize_linear(double scale, int64_t zero_point, ScalarType dtype) const {
 inline Tensor Tensor::dequantize() const {
   return dispatch_type().dequantize(*this);
 }
+inline Tensor Tensor::dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const {
+  return dispatch_type().dequantize_linear(*this, scale, zero_point, dtype);
+}
 inline Scalar Tensor::q_scale() const {
   return dispatch_type().q_scale(*this);
 }


@@ -393,6 +393,7 @@ struct CAFFE2_API Type {
   virtual Tensor to_mkldnn(const Tensor & self) const = 0;
   virtual Tensor quantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
   virtual Tensor dequantize(const Tensor & self) const = 0;
+  virtual Tensor dequantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
   virtual Scalar q_scale(const Tensor & self) const = 0;
   virtual Scalar q_zero_point(const Tensor & self) const = 0;
   virtual Tensor int_repr(const Tensor & self) const = 0;


@@ -2551,6 +2551,11 @@
   dispatch:
     QuantizedCPU: dequantize_quant
 
+- func: dequantize_linear(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: dequantize_linear_cpu
+
 - func: q_scale(Tensor self) -> Scalar
   variants: function, method
   dispatch:
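For orientation: schema types in native_functions.yaml lower to the C++ signatures seen in the earlier hunks (float becomes double, int becomes int64_t), and variants: function, method generates both at::dequantize_linear(t, ...) and t.dequantize_linear(...). A hedged sketch of the method variant (names and values invented for illustration):

#include <ATen/ATen.h>

// Sketch only: exercises the generated method variant.
void demo_method_variant() {
  // A plain CPU uint8 tensor of ones (not a quantized QTensor).
  at::Tensor q = at::ones({4}, at::kByte);
  at::Tensor f = q.dequantize_linear(/*scale=*/0.25, /*zero_point=*/1, at::kFloat);
  // Every element becomes (1 - 1) * 0.25 = 0.0f.
}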


@@ -16,6 +16,21 @@ Tensor dequantize_quant(const Tensor& self) {
   return get_qtensorimpl(self)->quantizer()->dequantize(self);
 }
 
+Tensor dequantize_linear_cpu(const Tensor& self, double scale, int64_t zero_point, ScalarType dtype) {
+  AT_CHECK(isQIntType(toQIntType(self.scalar_type())),
+           "Scalar type for quantized Tensor must have same underlying type as input.");
+  AT_CHECK(dtype == ScalarType::Float, "ScalarType for target Tensor must be float.");
+  Tensor f = at::empty(self.sizes(), self.options().dtype(dtype));
+  AT_DISPATCH_QINT_TYPES(
+      toQIntType(self.scalar_type()), "dequantize_linear_cpu", [&]() {
+        underlying_t* qdata = self.data<underlying_t>();
+        auto* fdata = f.data<float>();
+        for (int i = 0; i < self.numel(); ++i) {
+          fdata[i] = (static_cast<float>(qdata[i]) - zero_point) * scale;
+        }});
+  return f;
+}
+
 Scalar q_scale_quant(const Tensor& self) {
   auto quantizer = get_qtensorimpl(self)->quantizer();
   AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
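Taken together, dequantize_linear_cpu implements per-tensor affine dequantization, fdata[i] = (qdata[i] - zero_point) * scale, over a plain CPU integer tensor whose dtype maps to a quantized type via toQIntType. An end-to-end sketch using the function variant (tensor contents and parameters are invented for illustration):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Plain CPU uint8 tensor holding already-quantized values (not a QTensor).
  at::Tensor q = at::tensor({0, 10, 20, 255}, at::kByte);
  // f[i] = (q[i] - zero_point) * scale
  at::Tensor f = at::dequantize_linear(q, /*scale=*/0.5, /*zero_point=*/10, at::kFloat);
  std::cout << f << std::endl;  // expect -5.0, 0.0, 5.0, 122.5
  return 0;
}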